Add the token bucket algorithm.
Build and Push CodeReview / build (push) Waiting to run Details

This commit is contained in:
vinland100 2026-01-08 17:40:53 +08:00
parent 180ae67b7e
commit 2415b95428
1 changed file with 42 additions and 8 deletions

View File

@ -6,6 +6,7 @@
import asyncio
import hashlib
import logging
import time
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass
@ -664,21 +665,26 @@ class EmbeddingService:
base_url=self.base_url,
)
# 🔥 控制并发请求数 (RPS 限制)
# 🔥 控制并发请求数 RPS 限制
is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", 2 if is_remote else 10)
# 设置最大并发数,与 RPS 保持一致以最大化吞吐
self.max_rps = getattr(settings, "EMBEDDING_RPS", 30 if is_remote else 100)
self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", self.max_rps if is_remote else 10)
self._semaphore = asyncio.Semaphore(self.concurrency)
# 🔥 RPS 令牌桶限流器
self._rps_tokens = self.max_rps # 当前可用令牌数
self._rps_last_refill = time.monotonic() # 上次补充时间
self._rps_lock = asyncio.Lock() # 保护令牌桶的锁
# 🔥 设置默认批次大小 (DashScope text-embedding-v4 限制为 10)
self.batch_size = 10 if is_remote else 100
self.batch_size = getattr(settings, "EMBEDDING_BATCH_SIZE", 10 if is_remote else 100)
# 🔥 共享 HTTP 客户端
self._client: Optional[httpx.AsyncClient] = None
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (RPS: {self.max_rps}, Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
async def _get_client(self) -> httpx.AsyncClient:
"""获取或创建共享的 AsyncClient"""
@ -849,6 +855,31 @@ class EmbeddingService:
# 确保没有 None
return [e if e is not None else [0.0] * self.dimension for e in embeddings]
async def _acquire_rps_token(self):
    """Acquire one RPS token, waiting until the token bucket can supply it.

    The bucket refills continuously at ``self.max_rps`` tokens per second,
    measured with ``time.monotonic()``. Each successful call consumes exactly
    one token. When less than one token is available, the coroutine sleeps
    for the time needed to accumulate the missing fraction and then tries
    again.

    Reads ``self.max_rps`` and ``self._rps_lock``; reads and updates
    ``self._rps_tokens`` and ``self._rps_last_refill``.

    Notes:
        * A non-positive ``max_rps`` is treated as "rate limiting disabled"
          and returns immediately, instead of dividing by zero.
        * The bucket capacity is never smaller than one token, so a
          fractional rate (for example 0.5 requests per second) can still
          accumulate a whole token rather than waiting forever.
    """
    rate = self.max_rps
    if rate <= 0:
        # No meaningful refill rate: skip limiting rather than crash on
        # the division below.
        return

    # Capacity must hold at least one whole token, otherwise the
    # ``>= 1`` check could never succeed for rates below 1.
    capacity = max(1, rate)

    # Retry in a loop instead of recursing, so the coroutine nesting depth
    # stays constant no matter how many times we have to wait.
    while True:
        async with self._rps_lock:
            now = time.monotonic()
            elapsed = now - self._rps_last_refill

            # Refill: ``rate`` tokens per elapsed second, capped at capacity.
            self._rps_tokens = min(capacity, self._rps_tokens + elapsed * rate)
            self._rps_last_refill = now

            if self._rps_tokens >= 1:
                self._rps_tokens -= 1
                return

            # Not enough tokens: time until the missing fraction is refilled.
            wait_time = (1 - self._rps_tokens) / rate

        # Sleep outside the lock so other callers are not blocked meanwhile.
        await asyncio.sleep(wait_time)
async def _process_batch_with_retry(
self,
batch: List[str],
@ -856,7 +887,7 @@ class EmbeddingService:
cancel_check: Optional[callable] = None,
max_retries: Optional[int] = None
) -> List[EmbeddingResult]:
"""带重试机制的单批次处理"""
"""带重试机制和 RPS 限流的单批次处理"""
# 优先使用配置中的重试次数
actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
@ -866,6 +897,9 @@ class EmbeddingService:
if cancel_check and cancel_check():
raise asyncio.CancelledError("嵌入操作已取消")
# 🔥 先获取 RPS 令牌,确保不超过每秒请求数限制
await self._acquire_rps_token()
async with self._semaphore:
try:
return await self._provider.embed_texts(batch, client=client)