Add the token bucket algorithm.
Build and Push CodeReview / build (push) Waiting to run
Details
Build and Push CodeReview / build (push) Waiting to run
Details
This commit is contained in:
parent
180ae67b7e
commit
2415b95428
|
|
@ -6,6 +6,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
@ -664,21 +665,26 @@ class EmbeddingService:
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 控制并发请求数 (RPS 限制)
|
# 🔥 控制并发请求数和 RPS 限制
|
||||||
is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
|
is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
|
||||||
self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", 2 if is_remote else 10)
|
|
||||||
|
# 设置最大并发数,与 RPS 保持一致以最大化吞吐
|
||||||
|
self.max_rps = getattr(settings, "EMBEDDING_RPS", 30 if is_remote else 100)
|
||||||
|
self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", self.max_rps if is_remote else 10)
|
||||||
self._semaphore = asyncio.Semaphore(self.concurrency)
|
self._semaphore = asyncio.Semaphore(self.concurrency)
|
||||||
|
|
||||||
|
# 🔥 RPS 令牌桶限流器
|
||||||
|
self._rps_tokens = self.max_rps # 当前可用令牌数
|
||||||
|
self._rps_last_refill = time.monotonic() # 上次补充时间
|
||||||
|
self._rps_lock = asyncio.Lock() # 保护令牌桶的锁
|
||||||
|
|
||||||
# 🔥 设置默认批次大小 (DashScope text-embedding-v4 限制为 10)
|
# 🔥 设置默认批次大小 (DashScope text-embedding-v4 限制为 10)
|
||||||
self.batch_size = 10 if is_remote else 100
|
self.batch_size = getattr(settings, "EMBEDDING_BATCH_SIZE", 10 if is_remote else 100)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 🔥 共享 HTTP 客户端
|
# 🔥 共享 HTTP 客户端
|
||||||
self._client: Optional[httpx.AsyncClient] = None
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
|
logger.info(f"Embedding service initialized with {self.provider}/{self.model} (RPS: {self.max_rps}, Concurrency: {self.concurrency}, Batch size: {self.batch_size})")
|
||||||
|
|
||||||
async def _get_client(self) -> httpx.AsyncClient:
|
async def _get_client(self) -> httpx.AsyncClient:
|
||||||
"""获取或创建共享的 AsyncClient"""
|
"""获取或创建共享的 AsyncClient"""
|
||||||
|
|
@ -849,6 +855,31 @@ class EmbeddingService:
|
||||||
# 确保没有 None
|
# 确保没有 None
|
||||||
return [e if e is not None else [0.0] * self.dimension for e in embeddings]
|
return [e if e is not None else [0.0] * self.dimension for e in embeddings]
|
||||||
|
|
||||||
|
async def _acquire_rps_token(self):
|
||||||
|
"""获取 RPS 令牌(令牌桶算法)"""
|
||||||
|
async with self._rps_lock:
|
||||||
|
now = time.monotonic()
|
||||||
|
elapsed = now - self._rps_last_refill
|
||||||
|
|
||||||
|
# 补充令牌:每秒补充 max_rps 个令牌
|
||||||
|
self._rps_tokens = min(
|
||||||
|
self.max_rps,
|
||||||
|
self._rps_tokens + elapsed * self.max_rps
|
||||||
|
)
|
||||||
|
self._rps_last_refill = now
|
||||||
|
|
||||||
|
if self._rps_tokens >= 1:
|
||||||
|
self._rps_tokens -= 1
|
||||||
|
return
|
||||||
|
|
||||||
|
# 没有令牌,计算等待时间
|
||||||
|
wait_time = (1 - self._rps_tokens) / self.max_rps
|
||||||
|
|
||||||
|
# 在锁外等待
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
# 递归获取令牌
|
||||||
|
await self._acquire_rps_token()
|
||||||
|
|
||||||
async def _process_batch_with_retry(
|
async def _process_batch_with_retry(
|
||||||
self,
|
self,
|
||||||
batch: List[str],
|
batch: List[str],
|
||||||
|
|
@ -856,7 +887,7 @@ class EmbeddingService:
|
||||||
cancel_check: Optional[callable] = None,
|
cancel_check: Optional[callable] = None,
|
||||||
max_retries: Optional[int] = None
|
max_retries: Optional[int] = None
|
||||||
) -> List[EmbeddingResult]:
|
) -> List[EmbeddingResult]:
|
||||||
"""带重试机制的单批次处理"""
|
"""带重试机制和 RPS 限流的单批次处理"""
|
||||||
# 优先使用配置中的重试次数
|
# 优先使用配置中的重试次数
|
||||||
actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
|
actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
|
||||||
|
|
||||||
|
|
@ -866,6 +897,9 @@ class EmbeddingService:
|
||||||
if cancel_check and cancel_check():
|
if cancel_check and cancel_check():
|
||||||
raise asyncio.CancelledError("嵌入操作已取消")
|
raise asyncio.CancelledError("嵌入操作已取消")
|
||||||
|
|
||||||
|
# 🔥 先获取 RPS 令牌,确保不超过每秒请求数限制
|
||||||
|
await self._acquire_rps_token()
|
||||||
|
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
try:
|
try:
|
||||||
return await self._provider.embed_texts(batch, client=client)
|
return await self._provider.embed_texts(batch, client=client)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue