# CodeReview/backend/app/services/rag/embeddings.py


"""
嵌入模型服务
支持多种嵌入模型提供商: OpenAI, Azure, Ollama, Cohere, HuggingFace, Jina
"""
import asyncio
import hashlib
import logging
import time
from typing import List, Dict, Any, Optional, Callable
from abc import ABC, abstractmethod
from dataclasses import dataclass

import httpx

from app.core.config import settings

logger = logging.getLogger(__name__)

@dataclass
class EmbeddingResult:
    """Result of an embedding call."""
    embedding: List[float]
    tokens_used: int
    model: str


class EmbeddingProvider(ABC):
    """Base class for embedding providers."""

    @abstractmethod
    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text."""
        pass

    @abstractmethod
    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed a batch of texts."""
        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Dimensionality of the embedding vectors."""
        pass
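

# --- Illustrative sketch (not used in production) ---
# A hedged, minimal example of the EmbeddingProvider contract above. The fixed
# dimension, zero vectors, and the ~4-characters-per-token estimate are
# assumptions for illustration only; a real provider calls an embedding API.
class _ZeroEmbedding(EmbeddingProvider):
    """Example provider that returns zero vectors of a fixed dimension."""

    def __init__(self, dimension: int = 8):
        self._dimension = dimension

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        # One zero vector per input; token count estimated from text length.
        return [
            EmbeddingResult(
                embedding=[0.0] * self._dimension,
                tokens_used=len(text) // 4,
                model="zero-example",
            )
            for text in texts
        ]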


class OpenAIEmbedding(EmbeddingProvider):
    """OpenAI embedding service."""

    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        "Qwen3-Embedding-4B": 2560,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-small",
    ):
        # Prefer the explicit argument, then EMBEDDING_API_KEY, then LLM_API_KEY.
        self.api_key = (
            api_key
            or getattr(settings, "EMBEDDING_API_KEY", None)
            or settings.LLM_API_KEY
        )
        # Prefer the explicit argument, then EMBEDDING_BASE_URL, then the OpenAI default.
        self.base_url = (
            base_url
            or getattr(settings, "EMBEDDING_BASE_URL", None)
            or "https://api.openai.com/v1"
        )
        self.model = model
        # Prefer an explicitly configured dimension, then the model's known
        # dimension, and fall back to 1536.
        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1536)
        )

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        if not texts:
            return []
        # Rough character-based truncation; the 8191 limit is actually in tokens.
        max_length = 8191
        truncated_texts = [text[:max_length] for text in texts]
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": truncated_texts,
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"
        if client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
        else:
            async with httpx.AsyncClient(timeout=60) as owned_client:
                response = await owned_client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
        results = []
        for item in data.get("data", []):
            results.append(EmbeddingResult(
                embedding=item["embedding"],
                # Total usage is split evenly across inputs as an approximation.
                tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
                model=self.model,
            ))
        return results


class AzureOpenAIEmbedding(EmbeddingProvider):
    """
    Azure OpenAI embedding service.

    Uses the 2024-10-21 (GA) API version. Endpoint format:
    https://<resource>.openai.azure.com/openai/deployments/<deployment>/embeddings?api-version=2024-10-21
    """

    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        "Qwen3-Embedding-4B": 2560,
    }

    # Latest GA API version.
    API_VERSION = "2024-10-21"

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-small",
    ):
        self.api_key = api_key
        self.base_url = base_url or "https://your-resource.openai.azure.com"
        self.model = model
        # Prefer an explicitly configured dimension from project settings.
        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1536)
        )

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        if not texts:
            return []
        # Rough character-based truncation; the 8191 limit is actually in tokens.
        max_length = 8191
        truncated_texts = [text[:max_length] for text in texts]
        headers = {
            "api-key": self.api_key,
            "Content-Type": "application/json",
        }
        payload = {
            "input": truncated_texts,
        }
        # Azure URL format, pinned to the latest GA API version.
        url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
        if client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
        else:
            async with httpx.AsyncClient(timeout=60) as owned_client:
                response = await owned_client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
        results = []
        for item in data.get("data", []):
            results.append(EmbeddingResult(
                embedding=item["embedding"],
                # Total usage is split evenly across inputs as an approximation.
                tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
                model=self.model,
            ))
        return results


class OllamaEmbedding(EmbeddingProvider):
    """
    Ollama local embedding service.

    Uses the newer /api/embed endpoint (available since 2024):
    - supports batch embedding
    - takes an 'input' parameter that accepts a string or a list of strings
    """

    MODELS = {
        "nomic-embed-text": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
        "bge-m3": 1024,
        "qwen3-embedding": 1024,
    }

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: str = "nomic-embed-text",
    ):
        self.base_url = base_url or "http://localhost:11434"
        self.model = model
        self._dimension = self.MODELS.get(model, 768)

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        if not texts:
            return []
        # Newer Ollama /api/embed endpoint.
        url = f"{self.base_url.rstrip('/')}/api/embed"
        payload = {
            "model": self.model,
            "input": texts,  # the new API takes 'input' and supports batching
        }
        if client:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            data = response.json()
        else:
            async with httpx.AsyncClient(timeout=120) as owned_client:
                response = await owned_client.post(url, json=payload)
                response.raise_for_status()
                data = response.json()
        # New API response format: {"embeddings": [[...], [...], ...]}
        embeddings = data.get("embeddings", [])
        results = []
        for i, embedding in enumerate(embeddings):
            results.append(EmbeddingResult(
                embedding=embedding,
                # Ollama reports no usage; estimate at ~4 characters per token.
                tokens_used=len(texts[i]) // 4,
                model=self.model,
            ))
        return results


class CohereEmbedding(EmbeddingProvider):
    """
    Cohere embedding service.

    Uses the v2 API (available since 2024):
    - endpoint: https://api.cohere.com/v2/embed
    - uses the 'inputs' parameter in place of 'texts'
    - requires 'embedding_types' to be specified
    """

    MODELS = {
        "embed-english-v3.0": 1024,
        "embed-multilingual-v3.0": 1024,
        "embed-english-light-v3.0": 384,
        "embed-multilingual-light-v3.0": 384,
        "embed-v4.0": 1024,  # latest model
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "embed-multilingual-v3.0",
    ):
        self.api_key = api_key
        # v2 API endpoint.
        self.base_url = base_url or "https://api.cohere.com/v2"
        self.model = model
        self._dimension = self.MODELS.get(model, 1024)

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        if not texts:
            return []
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # v2 API parameter format.
        payload = {
            "model": self.model,
            "inputs": texts,  # v2 uses 'inputs' rather than 'texts'
            "input_type": "search_document",
            "embedding_types": ["float"],  # v2 requires the embedding type
        }
        url = f"{self.base_url.rstrip('/')}/embed"
        if client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
        else:
            async with httpx.AsyncClient(timeout=60) as owned_client:
                response = await owned_client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
        results = []
        # v2 API response format: {"embeddings": {"float": [[...], [...]]}, ...}
        embeddings_data = data.get("embeddings", {})
        embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
        for embedding in embeddings:
            results.append(EmbeddingResult(
                embedding=embedding,
                # Billed input tokens split evenly across inputs as an approximation.
                tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
                model=self.model,
            ))
        return results


class HuggingFaceEmbedding(EmbeddingProvider):
    """
    HuggingFace Inference Providers embedding service.

    Uses the newer Router endpoint (available since 2025):
    https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
    """

    MODELS = {
        "sentence-transformers/all-MiniLM-L6-v2": 384,
        "sentence-transformers/all-mpnet-base-v2": 768,
        "BAAI/bge-large-zh-v1.5": 1024,
        "BAAI/bge-m3": 1024,
        "BAAI/bge-small-en-v1.5": 384,
        "BAAI/bge-base-en-v1.5": 768,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "BAAI/bge-m3",
    ):
        self.api_key = api_key
        # Newer Router endpoint.
        self.base_url = base_url or "https://router.huggingface.co"
        self.model = model
        self._dimension = self.MODELS.get(model, 1024)

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        if not texts:
            return []
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # Newer HuggingFace Router URL format.
        url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
        payload = {
            "inputs": texts,
            "options": {
                "wait_for_model": True,
            }
        }
        if client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
        else:
            async with httpx.AsyncClient(timeout=120) as owned_client:
                response = await owned_client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
        results = []
        for embedding in data:
            # Some models return nested (token-level) vectors; unwrap to the pooled vector.
            if isinstance(embedding, list) and len(embedding) > 0:
                if isinstance(embedding[0], list):
                    embedding = embedding[0]
            results.append(EmbeddingResult(
                embedding=embedding,
                # No usage reported; estimate at ~4 characters per token.
                tokens_used=len(texts[len(results)]) // 4,
                model=self.model,
            ))
        return results


class JinaEmbedding(EmbeddingProvider):
    """Jina AI embedding service."""

    MODELS = {
        "jina-embeddings-v2-base-code": 768,
        "jina-embeddings-v2-base-en": 768,
        "jina-embeddings-v2-base-zh": 768,
        "jina-embeddings-v2-small-en": 512,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "jina-embeddings-v2-base-code",
    ):
        self.api_key = api_key
        self.base_url = base_url or "https://api.jina.ai/v1"
        self.model = model
        self._dimension = self.MODELS.get(model, 768)

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        if not texts:
            return []
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": texts,
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"
        if client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
        else:
            async with httpx.AsyncClient(timeout=60) as owned_client:
                response = await owned_client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
        results = []
        for item in data.get("data", []):
            results.append(EmbeddingResult(
                embedding=item["embedding"],
                # Total usage is split evenly across inputs as an approximation.
                tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts),
                model=self.model,
            ))
        return results


class QwenEmbedding(EmbeddingProvider):
    """Qwen embedding service (backed by the Alibaba Cloud DashScope embeddings API)."""

    MODELS = {
        # DashScope Qwen embedding models and their default dimensions.
        "text-embedding-v4": 1024,  # supported dimensions: 2048, 1536, 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v3": 1024,  # supported dimensions: 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v2": 1536,  # supported dimension: 1536
        "Qwen3-Embedding-4B": 2560,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-v4",
    ):
        # Prefer the explicit api_key, then EMBEDDING_API_KEY / QWEN_API_KEY / LLM_API_KEY.
        self.api_key = (
            api_key
            or getattr(settings, "EMBEDDING_API_KEY", None)
            or getattr(settings, "QWEN_API_KEY", None)
            or settings.LLM_API_KEY
        )
        # 🔥 Validate the API key up front.
        if not self.api_key:
            raise ValueError(
                "Qwen embedding requires API key. "
                "Set EMBEDDING_API_KEY, QWEN_API_KEY or LLM_API_KEY environment variable."
            )
        # DashScope's OpenAI-compatible embeddings endpoint.
        self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        self.model = model
        # Prefer an explicitly configured dimension, then the model's known
        # dimension, and fall back to 1024.
        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0)
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1024)
        )

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        if not texts:
            return []
        # Same truncation strategy as the OpenAI provider.
        max_length = 8191
        truncated_texts = [text[:max_length] for text in texts]
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": truncated_texts,
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"
        try:
            if client:
                response = await client.post(url, headers=headers, json=payload)
            else:
                async with httpx.AsyncClient(timeout=60.0) as owned_client:
                    response = await owned_client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            usage = data.get("usage", {}) or {}
            total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
            results: List[EmbeddingResult] = []
            for item in data.get("data", []):
                results.append(EmbeddingResult(
                    embedding=item["embedding"],
                    tokens_used=total_tokens // max(len(texts), 1),
                    model=self.model,
                ))
            return results
        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}")
            raise
        except httpx.RequestError as e:
            logger.error(f"Qwen embedding network error: {e}")
            raise RuntimeError(f"Qwen embedding network error: {e}") from e
        except Exception as e:
            logger.error(f"Qwen embedding unexpected error: {e}")
            raise RuntimeError(f"Qwen embedding failed: {e}") from e


class EmbeddingService:
    """
    Embedding service.

    Provides unified management of embedding models and caching.

    Supported providers:
    - openai: OpenAI (official)
    - azure: Azure OpenAI
    - ollama: local Ollama
    - cohere: Cohere
    - huggingface: HuggingFace Inference API
    - jina: Jina AI
    - qwen: Qwen (DashScope)
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        cache_enabled: bool = True,
    ):
        """
        Initialize the embedding service.

        Args:
            provider: provider name (openai, azure, ollama, cohere, huggingface, jina, qwen)
            model: model name
            api_key: API key
            base_url: API base URL
            cache_enabled: whether to enable caching
        """
        self.cache_enabled = cache_enabled
        self._cache: Dict[str, List[float]] = {}
        # Resolve the provider (keep the raw value for attribute access).
        self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
        self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
        self.api_key = api_key or getattr(settings, "EMBEDDING_API_KEY", None)
        self.base_url = base_url or getattr(settings, "EMBEDDING_BASE_URL", None)
        # Create the provider instance.
        self._provider = self._create_provider(
            provider=self.provider,
            model=self.model,
            api_key=self.api_key,
            base_url=self.base_url,
        )
        # 🔥 Cap concurrent requests and enforce an RPS limit.
        is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
        # Align the maximum concurrency with the RPS limit to maximize throughput.
        self.max_rps = getattr(settings, "EMBEDDING_RPS", 30 if is_remote else 100)
        self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", self.max_rps if is_remote else 10)
        self._semaphore = asyncio.Semaphore(self.concurrency)
        # 🔥 Token-bucket rate limiter for the RPS cap.
        self._rps_tokens = self.max_rps  # tokens currently available
        self._rps_last_refill = time.monotonic()  # time of the last refill
        self._rps_lock = asyncio.Lock()  # lock protecting the token bucket
        # 🔥 Default batch size (DashScope text-embedding-v4 caps batches at 10).
        self.batch_size = getattr(settings, "EMBEDDING_BATCH_SIZE", 10 if is_remote else 100)
        # 🔥 Shared HTTP client.
        self._client: Optional[httpx.AsyncClient] = None
        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (RPS: {self.max_rps}, Concurrency: {self.concurrency}, Batch size: {self.batch_size})")

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the shared AsyncClient."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=120.0)
        return self._client

    async def close(self):
        """Close the shared client."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    def _create_provider(
        self,
        provider: str,
        model: str,
        api_key: Optional[str],
        base_url: Optional[str],
    ) -> EmbeddingProvider:
        """Create an embedding provider instance."""
        provider = provider.lower()
        if provider == "ollama":
            return OllamaEmbedding(base_url=base_url, model=model)
        elif provider == "azure":
            return AzureOpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)
        elif provider == "cohere":
            return CohereEmbedding(api_key=api_key, base_url=base_url, model=model)
        elif provider == "huggingface":
            return HuggingFaceEmbedding(api_key=api_key, base_url=base_url, model=model)
        elif provider == "jina":
            return JinaEmbedding(api_key=api_key, base_url=base_url, model=model)
        elif provider == "qwen":
            return QwenEmbedding(api_key=api_key, base_url=base_url, model=model)
        else:
            # Default to OpenAI.
            return OpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)

    @property
    def dimension(self) -> int:
        """Dimensionality of the embedding vectors."""
        return self._provider.dimension

    def _cache_key(self, text: str) -> str:
        """Build a cache key from the text content."""
        return hashlib.sha256(text.encode()).hexdigest()[:32]

    async def embed(self, text: str) -> List[float]:
        """
        Embed a single text.

        Args:
            text: text content

        Returns:
            The embedding vector.
        """
        if not text or not text.strip():
            return [0.0] * self.dimension
        # Check the cache first.
        if self.cache_enabled:
            cache_key = self._cache_key(text)
            if cache_key in self._cache:
                return self._cache[cache_key]
        result = await self._provider.embed_text(text)
        # Store in the cache.
        if self.cache_enabled:
            self._cache[cache_key] = result.embedding
        return result.embedding

    async def embed_batch(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        show_progress: bool = False,
        progress_callback: Optional[Callable] = None,
        cancel_check: Optional[Callable] = None,
    ) -> List[List[float]]:
        """
        Embed a batch of texts.

        Args:
            texts: list of texts
            batch_size: batch size (defaults to the configured service batch size)
            show_progress: whether to show progress
            progress_callback: progress callback, called with (processed, total)
            cancel_check: cancellation check; returning True aborts the operation

        Returns:
            List of embedding vectors.

        Raises:
            asyncio.CancelledError: if cancel_check returns True.
        """
        if not texts:
            return []
        embeddings = []
        uncached_indices = []
        uncached_texts = []
        # Check the cache.
        for i, text in enumerate(texts):
            if not text or not text.strip():
                embeddings.append([0.0] * self.dimension)
                continue
            if self.cache_enabled:
                cache_key = self._cache_key(text)
                if cache_key in self._cache:
                    embeddings.append(self._cache[cache_key])
                    continue
            embeddings.append(None)  # placeholder
            uncached_indices.append(i)
            uncached_texts.append(text)
        # Embed the uncached texts in batches.
        if uncached_texts:
            tasks = []
            current_batch_size = batch_size or self.batch_size
            for i in range(0, len(uncached_texts), current_batch_size):
                batch = uncached_texts[i:i + current_batch_size]
                batch_indices = uncached_indices[i:i + current_batch_size]
                tasks.append(self._process_batch_with_retry(batch, batch_indices, cancel_check))
            # 🔥 Run all batch tasks concurrently.
            all_batch_results = await asyncio.gather(*tasks, return_exceptions=True)
            for i, result_list in enumerate(all_batch_results):
                batch_indices = uncached_indices[i * current_batch_size : (i + 1) * current_batch_size]
                if isinstance(result_list, Exception):
                    logger.error(f"Batch processing failed: {result_list}")
                    # Fall back to zero vectors for the failed batch.
                    for idx in batch_indices:
                        if embeddings[idx] is None:
                            embeddings[idx] = [0.0] * self.dimension
                    continue
                for idx, result in zip(batch_indices, result_list):
                    embeddings[idx] = result.embedding
                    # Store in the cache.
                    if self.cache_enabled:
                        cache_key = self._cache_key(texts[idx])
                        self._cache[cache_key] = result.embedding
                # 🔥 Invoke the progress callback.
                if progress_callback:
                    processed_count = min((i + 1) * current_batch_size, len(uncached_texts))
                    try:
                        progress_callback(processed_count, len(uncached_texts))
                    except Exception as e:
                        logger.warning(f"Progress callback error: {e}")
        # Make sure no None placeholders remain.
        return [e if e is not None else [0.0] * self.dimension for e in embeddings]

    async def _acquire_rps_token(self):
        """Acquire an RPS token (token-bucket algorithm)."""
        async with self._rps_lock:
            now = time.monotonic()
            elapsed = now - self._rps_last_refill
            # Refill: max_rps tokens per second.
            self._rps_tokens = min(
                self.max_rps,
                self._rps_tokens + elapsed * self.max_rps
            )
            self._rps_last_refill = now
            if self._rps_tokens >= 1:
                self._rps_tokens -= 1
                return
            # No token available; compute the wait time.
            wait_time = (1 - self._rps_tokens) / self.max_rps
        # Sleep outside the lock.
        await asyncio.sleep(wait_time)
        # Retry recursively until a token is acquired.
        await self._acquire_rps_token()
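
    # Worked example for the token bucket above (numbers illustrative): with
    # max_rps = 30, tokens refill at 30 per second. A caller that finds
    # _rps_tokens = 0.4 cannot take a whole token, so it sleeps for
    # (1 - 0.4) / 30 = 0.02 s, after which the bucket has refilled past 1.0
    # and the retry succeeds. Bursts are capped at max_rps tokens.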

    async def _process_batch_with_retry(
        self,
        batch: List[str],
        indices: List[int],
        cancel_check: Optional[Callable] = None,
        max_retries: Optional[int] = None
    ) -> List[EmbeddingResult]:
        """Process a single batch with retries and RPS throttling."""
        # Prefer the retry count from configuration.
        actual_max_retries = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
        client = await self._get_client()
        for attempt in range(actual_max_retries):
            if cancel_check and cancel_check():
                raise asyncio.CancelledError("Embedding operation cancelled")
            # 🔥 Acquire an RPS token first so the per-second request cap holds.
            await self._acquire_rps_token()
            async with self._semaphore:
                try:
                    return await self._provider.embed_texts(batch, client=client)
                except httpx.HTTPStatusError as e:
                    if e.response.status_code == 429 and attempt < actual_max_retries - 1:
                        # Rate limited (429): back off exponentially and conservatively.
                        # First retry waits 5s, then 10s, then 20s, ...
                        wait_time = (2 ** attempt) * 5
                        logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{actual_max_retries})")
                        await asyncio.sleep(wait_time)
                        continue
                    raise
                except Exception as e:
                    if attempt < actual_max_retries - 1:
                        # Other errors: wait 2s before retrying.
                        await asyncio.sleep(2)
                        continue
                    raise
        return []

    def clear_cache(self):
        """Clear the cache."""
        self._cache.clear()

    @property
    def cache_size(self) -> int:
        """Number of cached embeddings."""
        return len(self._cache)
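

# --- Usage sketch (illustrative, not called in production) ---
# A minimal, hedged example of driving EmbeddingService end to end. It assumes
# provider/model/API-key configuration comes from settings (EMBEDDING_* /
# LLM_API_KEY); the sample texts and the print-based progress callback are
# made up for illustration.
async def _example_usage() -> None:
    service = EmbeddingService()  # falls back to the settings-based defaults
    try:
        vector = await service.embed("def add(a, b): return a + b")
        print(f"single embedding, dimension = {len(vector)}")

        chunks = [f"code chunk {i}" for i in range(25)]
        vectors = await service.embed_batch(
            chunks,
            progress_callback=lambda done, total: print(f"progress: {done}/{total}"),
        )
        print(f"embedded {len(vectors)} chunks, cache size = {service.cache_size}")
    finally:
        await service.close()  # release the shared HTTP client


if __name__ == "__main__":
    asyncio.run(_example_usage())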