Add the token bucket algorithm.
Build and Push CodeReview / build (push) Waiting to run Details

This commit is contained in:
vinland100 2026-01-08 17:40:53 +08:00
parent 180ae67b7e
commit 2415b95428
1 changed files with 42 additions and 8 deletions

View File

@ -6,6 +6,7 @@
import asyncio import asyncio
import hashlib import hashlib
import logging import logging
import time
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
@ -664,21 +665,26 @@ class EmbeddingService:
base_url=self.base_url, base_url=self.base_url,
) )
# 🔥 控制并发请求数 (RPS 限制) # 🔥 控制并发请求数 RPS 限制
is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"] is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", 2 if is_remote else 10)
# 设置最大并发数,与 RPS 保持一致以最大化吞吐
self.max_rps = getattr(settings, "EMBEDDING_RPS", 30 if is_remote else 100)
self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", self.max_rps if is_remote else 10)
self._semaphore = asyncio.Semaphore(self.concurrency) self._semaphore = asyncio.Semaphore(self.concurrency)
# 🔥 RPS 令牌桶限流器
self._rps_tokens = self.max_rps # 当前可用令牌数
self._rps_last_refill = time.monotonic() # 上次补充时间
self._rps_lock = asyncio.Lock() # 保护令牌桶的锁
# 🔥 设置默认批次大小 (DashScope text-embedding-v4 限制为 10) # 🔥 设置默认批次大小 (DashScope text-embedding-v4 限制为 10)
self.batch_size = 10 if is_remote else 100 self.batch_size = getattr(settings, "EMBEDDING_BATCH_SIZE", 10 if is_remote else 100)
# 🔥 共享 HTTP 客户端 # 🔥 共享 HTTP 客户端
self._client: Optional[httpx.AsyncClient] = None self._client: Optional[httpx.AsyncClient] = None
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})") logger.info(f"Embedding service initialized with {self.provider}/{self.model} (RPS: {self.max_rps}, Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
async def _get_client(self) -> httpx.AsyncClient: async def _get_client(self) -> httpx.AsyncClient:
"""获取或创建共享的 AsyncClient""" """获取或创建共享的 AsyncClient"""
@ -849,6 +855,31 @@ class EmbeddingService:
# 确保没有 None # 确保没有 None
return [e if e is not None else [0.0] * self.dimension for e in embeddings] return [e if e is not None else [0.0] * self.dimension for e in embeddings]
async def _acquire_rps_token(self):
"""获取 RPS 令牌(令牌桶算法)"""
async with self._rps_lock:
now = time.monotonic()
elapsed = now - self._rps_last_refill
# 补充令牌:每秒补充 max_rps 个令牌
self._rps_tokens = min(
self.max_rps,
self._rps_tokens + elapsed * self.max_rps
)
self._rps_last_refill = now
if self._rps_tokens >= 1:
self._rps_tokens -= 1
return
# 没有令牌,计算等待时间
wait_time = (1 - self._rps_tokens) / self.max_rps
# 在锁外等待
await asyncio.sleep(wait_time)
# 递归获取令牌
await self._acquire_rps_token()
async def _process_batch_with_retry( async def _process_batch_with_retry(
self, self,
batch: List[str], batch: List[str],
@ -856,7 +887,7 @@ class EmbeddingService:
cancel_check: Optional[callable] = None, cancel_check: Optional[callable] = None,
max_retries: Optional[int] = None max_retries: Optional[int] = None
) -> List[EmbeddingResult]: ) -> List[EmbeddingResult]:
"""带重试机制的单批次处理""" """带重试机制和 RPS 限流的单批次处理"""
# 优先使用配置中的重试次数 # 优先使用配置中的重试次数
actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5) actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
@ -866,6 +897,9 @@ class EmbeddingService:
if cancel_check and cancel_check(): if cancel_check and cancel_check():
raise asyncio.CancelledError("嵌入操作已取消") raise asyncio.CancelledError("嵌入操作已取消")
# 🔥 先获取 RPS 令牌,确保不超过每秒请求数限制
await self._acquire_rps_token()
async with self._semaphore: async with self._semaphore:
try: try:
return await self._provider.embed_texts(batch, client=client) return await self._provider.embed_texts(batch, client=client)