From 2415b95428cfae6585026662541920696ea2bab5 Mon Sep 17 00:00:00 2001
From: vinland100
Date: Thu, 8 Jan 2026 17:40:53 +0800
Subject: [PATCH] Add the token bucket algorithm.

---
 backend/app/services/rag/embeddings.py | 50 +++++++++++++++++++++-----
 1 file changed, 42 insertions(+), 8 deletions(-)

diff --git a/backend/app/services/rag/embeddings.py b/backend/app/services/rag/embeddings.py
index bcef650..99d048a 100644
--- a/backend/app/services/rag/embeddings.py
+++ b/backend/app/services/rag/embeddings.py
@@ -6,6 +6,7 @@
 import asyncio
 import hashlib
 import logging
+import time
 from typing import List, Dict, Any, Optional
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -664,21 +665,26 @@ class EmbeddingService:
             base_url=self.base_url,
         )
 
-        # 🔥 Limit the number of concurrent requests (RPS limit)
+        # 🔥 Limit concurrent requests and requests per second (RPS)
         is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
-        self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", 2 if is_remote else 10)
+
+        # Keep max concurrency in line with the RPS limit to maximize throughput
+        self.max_rps = getattr(settings, "EMBEDDING_RPS", 30 if is_remote else 100)
+        self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", self.max_rps if is_remote else 10)
         self._semaphore = asyncio.Semaphore(self.concurrency)
 
+        # 🔥 RPS token-bucket rate limiter
+        self._rps_tokens = self.max_rps  # tokens currently available
+        self._rps_last_refill = time.monotonic()  # time of the last refill
+        self._rps_lock = asyncio.Lock()  # lock protecting the token bucket
+
         # 🔥 Default batch size (DashScope text-embedding-v4 caps it at 10)
-        self.batch_size = 10 if is_remote else 100
-
-
-
+        self.batch_size = getattr(settings, "EMBEDDING_BATCH_SIZE", 10 if is_remote else 100)
 
         # 🔥 Shared HTTP client
         self._client: Optional[httpx.AsyncClient] = None
 
-        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
+        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (RPS: {self.max_rps}, Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
 
     async def _get_client(self) -> httpx.AsyncClient:
         """Get or create the shared AsyncClient"""
@@ -849,6 +855,31 @@ class EmbeddingService:
         # Make sure there are no None values
         return [e if e is not None else [0.0] * self.dimension for e in embeddings]
 
+    async def _acquire_rps_token(self):
+        """Acquire an RPS token (token bucket algorithm)"""
+        async with self._rps_lock:
+            now = time.monotonic()
+            elapsed = now - self._rps_last_refill
+
+            # Refill: max_rps tokens are added per second, capped at max_rps
+            self._rps_tokens = min(
+                self.max_rps,
+                self._rps_tokens + elapsed * self.max_rps
+            )
+            self._rps_last_refill = now
+
+            if self._rps_tokens >= 1:
+                self._rps_tokens -= 1
+                return
+
+            # No token available; compute how long until one is
+            wait_time = (1 - self._rps_tokens) / self.max_rps
+
+        # Wait outside the lock
+        await asyncio.sleep(wait_time)
+        # Then try to acquire again
+        await self._acquire_rps_token()
+
     async def _process_batch_with_retry(
         self,
         batch: List[str],
@@ -856,7 +887,7 @@ class EmbeddingService:
         cancel_check: Optional[callable] = None,
         max_retries: Optional[int] = None
     ) -> List[EmbeddingResult]:
-        """Process a single batch with retries"""
+        """Process a single batch with retries and RPS rate limiting"""
         # Prefer the retry count configured in settings
         actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
 
@@ -866,6 +897,9 @@ class EmbeddingService:
         if cancel_check and cancel_check():
             raise asyncio.CancelledError("Embedding operation was cancelled")
 
+        # 🔥 Acquire an RPS token first so the requests-per-second limit is never exceeded
+        await self._acquire_rps_token()
+
         async with self._semaphore:
             try:
                 return await self._provider.embed_texts(batch, client=client)
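
Below is a minimal, self-contained sketch of the token-bucket pattern this patch introduces, useful for trying the rate-limiting behaviour in isolation. The RpsLimiter class and fake_request coroutine are illustrative names only, not part of the repository; the sketch uses a retry loop where the patch recurses, but the refill arithmetic and the sleep-outside-the-lock step are the same.

import asyncio
import time


class RpsLimiter:
    # Simplified token bucket mirroring the patch's _acquire_rps_token logic.
    def __init__(self, max_rps: float):
        self.max_rps = max_rps
        self._tokens = max_rps               # start with a full bucket
        self._last_refill = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        while True:
            async with self._lock:
                now = time.monotonic()
                # Refill at max_rps tokens per second, capped at max_rps
                self._tokens = min(
                    self.max_rps,
                    self._tokens + (now - self._last_refill) * self.max_rps,
                )
                self._last_refill = now
                if self._tokens >= 1:
                    self._tokens -= 1
                    return
                # Not enough tokens: time until one token becomes available
                wait_time = (1 - self._tokens) / self.max_rps
            # Sleep outside the lock so other coroutines can refill/consume
            await asyncio.sleep(wait_time)


async def main() -> None:
    limiter = RpsLimiter(max_rps=3)
    start = time.monotonic()

    async def fake_request(i: int) -> None:
        await limiter.acquire()
        print(f"request {i} fired at t={time.monotonic() - start:.2f}s")

    # 10 simulated requests gated to 3 requests per second
    await asyncio.gather(*(fake_request(i) for i in range(10)))


if __name__ == "__main__":
    asyncio.run(main())

With max_rps=3 the first few requests fire immediately (the bucket starts full) and the remainder are spaced roughly a third of a second apart. Computing wait_time inside the lock but sleeping outside it keeps the lock hold time short, so waiting coroutines do not block each other's refill and consume steps.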