CodeReview/backend/app/services/rag/embeddings.py

898 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
嵌入模型服务
支持多种嵌入模型提供商: OpenAI, Azure, Ollama, Cohere, HuggingFace, Jina
"""
import asyncio
import hashlib
import logging
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass
import httpx
from app.core.config import settings
# Module-level logger used by the providers and by EmbeddingService.
logger = logging.getLogger(__name__)
@dataclass(init=True, repr=True, eq=True)
class EmbeddingResult:
    """Outcome of embedding a single piece of text.

    Fields:
        embedding: Vector returned by the provider.
        tokens_used: Tokens attributed to this text. Providers either split
            the batch total evenly across inputs or estimate it.
        model: Identifier of the model that produced the vector.
    """

    embedding: List[float]  # raw vector from the provider
    tokens_used: int  # per-text token count (may be an estimate)
    model: str  # model name used for the request
class EmbeddingProvider(ABC):
    """Abstract interface that every embedding backend implements."""

    @abstractmethod
    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed one string and return its result."""
        ...

    @abstractmethod
    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed several strings with one request.

        Args:
            texts: Strings to embed.
            client: Optional shared HTTP client. When it is None, the
                implementations in this module open a short-lived client of
                their own.
        """
        ...

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Vector length that this provider reports."""
        ...
class OpenAIEmbedding(EmbeddingProvider):
    """OpenAI (or OpenAI-compatible) embeddings via ``POST {base_url}/embeddings``."""

    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        "Qwen3-Embedding-4B": 2560,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-small",
    ):
        """
        Args:
            api_key: Explicit key. Falls back to settings.EMBEDDING_API_KEY,
                then settings.LLM_API_KEY.
            base_url: Explicit endpoint. Falls back to
                settings.EMBEDDING_BASE_URL, then the public OpenAI API.
            model: Model name sent with every request.
        """
        self.api_key = (
            api_key
            or getattr(settings, "EMBEDDING_API_KEY", None)
            or settings.LLM_API_KEY
        )
        self.base_url = (
            base_url
            or getattr(settings, "EMBEDDING_BASE_URL", None)
            or "https://api.openai.com/v1"
        )
        self.model = model
        # A positive EMBEDDING_DIMENSION overrides the model table; unknown
        # models default to 1536. `or 0` tolerates a None setting, which
        # would otherwise make the `> 0` comparison raise TypeError.
        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0) or 0
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1536)
        )

    @property
    def dimension(self) -> int:
        """Vector length reported for this model/configuration."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text through embed_texts."""
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` in one request.

        Args:
            texts: Inputs. Each is cut to its first 8191 characters.
            client: Shared client. When it is None, a temporary client with a
                60 s timeout is used.

        Returns:
            One result per returned item, ordered by the item's ``index``.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        if not texts:
            return []

        # Character-level cut. 8191 presumably mirrors OpenAI's per-input
        # token limit, but characters are not tokens, so this is only a rough guard.
        max_length = 8191
        truncated_texts = [text[:max_length] for text in texts]

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": truncated_texts,
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"

        if client is not None:
            response = await client.post(url, headers=headers, json=payload)
        else:
            async with httpx.AsyncClient(timeout=60) as own_client:
                response = await own_client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()

        # Each item's "index" refers back to its input position. Sorting on it
        # keeps results aligned with `texts`; the stable sort keeps server
        # order when the field is absent.
        items = sorted(data.get("data", []), key=lambda item: item.get("index", 0))
        # `or {}` also covers compatible servers that send "usage": null.
        usage = data.get("usage") or {}
        per_text_tokens = usage.get("total_tokens", 0) // len(texts)
        return [
            EmbeddingResult(
                embedding=item["embedding"],
                tokens_used=per_text_tokens,
                model=self.model,
            )
            for item in items
        ]
class AzureOpenAIEmbedding(EmbeddingProvider):
    """
    Azure OpenAI embeddings.

    Uses GA API version 2024-10-21. Endpoint format:
    https://<resource>.openai.azure.com/openai/deployments/<deployment>/embeddings?api-version=2024-10-21
    The ``model`` argument is used as the deployment name in the URL.
    """

    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        "Qwen3-Embedding-4B": 2560,
    }

    # Latest GA API version.
    API_VERSION = "2024-10-21"

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-small",
    ):
        """
        Args:
            api_key: Explicit key. Falls back to settings.EMBEDDING_API_KEY,
                then settings.LLM_API_KEY (same order as OpenAIEmbedding).
            base_url: Resource endpoint. Falls back to
                settings.EMBEDDING_BASE_URL, then a placeholder host that must
                be replaced for requests to succeed.
            model: Deployment name.
        """
        # Without a fallback, a missing key would be sent as a None
        # "api-key" header value.
        self.api_key = (
            api_key
            or getattr(settings, "EMBEDDING_API_KEY", None)
            or getattr(settings, "LLM_API_KEY", None)
        )
        self.base_url = (
            base_url
            or getattr(settings, "EMBEDDING_BASE_URL", None)
            or "https://your-resource.openai.azure.com"
        )
        self.model = model
        # A positive EMBEDDING_DIMENSION wins over the model table. `or 0`
        # tolerates a None setting (`None > 0` raises TypeError).
        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0) or 0
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1536)
        )

    @property
    def dimension(self) -> int:
        """Vector length reported for this deployment/configuration."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text through embed_texts."""
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` in one request to the deployment.

        Args:
            texts: Inputs. Each is cut to its first 8191 characters.
            client: Shared client. When it is None, a temporary client with a
                60 s timeout is used.

        Returns:
            One result per returned item, ordered by the item's ``index``.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        if not texts:
            return []

        max_length = 8191
        truncated_texts = [text[:max_length] for text in texts]

        headers = {
            "api-key": self.api_key,
            "Content-Type": "application/json",
        }
        # The deployment in the URL selects the model, so the body carries no "model".
        payload = {
            "input": truncated_texts,
        }
        url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"

        if client is not None:
            response = await client.post(url, headers=headers, json=payload)
        else:
            async with httpx.AsyncClient(timeout=60) as own_client:
                response = await own_client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()

        # Align results with inputs via each item's "index" (stable when absent).
        items = sorted(data.get("data", []), key=lambda item: item.get("index", 0))
        usage = data.get("usage") or {}
        per_text_tokens = usage.get("total_tokens", 0) // len(texts)
        return [
            EmbeddingResult(
                embedding=item["embedding"],
                tokens_used=per_text_tokens,
                model=self.model,
            )
            for item in items
        ]
class OllamaEmbedding(EmbeddingProvider):
    """
    Local Ollama embeddings.

    Uses the /api/embed endpoint, which takes an 'input' field (a string or
    a list of strings) and therefore supports batching.
    """

    MODELS = {
        "nomic-embed-text": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
        "bge-m3": 1024,
        "qwen3-embedding": 1024,
    }

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: str = "nomic-embed-text",
    ):
        """
        Args:
            base_url: Ollama server URL (default http://localhost:11434).
            model: Ollama model name, optionally with a tag.
        """
        self.base_url = base_url or "http://localhost:11434"
        self.model = model
        self._dimension = self._resolve_dimension(model)

    @classmethod
    def _resolve_dimension(cls, model: str) -> int:
        """Look up the vector size for ``model``.

        Resolution order:
        1. The model table. An explicit ":latest" tag is ignored, because
           Ollama treats "name:latest" and "name" as the same model.
        2. A positive settings.EMBEDDING_DIMENSION, for models missing from the table.
        3. 768.
        """
        suffix = ":latest"
        name = model[: -len(suffix)] if model.endswith(suffix) else model
        if name in cls.MODELS:
            return cls.MODELS[name]
        # `or 0` tolerates a None setting.
        explicit = getattr(settings, "EMBEDDING_DIMENSION", 0) or 0
        return explicit if explicit > 0 else 768

    @property
    def dimension(self) -> int:
        """Vector length reported for this model."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text through embed_texts."""
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` with one /api/embed call.

        Args:
            texts: Inputs, sent unmodified.
            client: Shared client. When it is None, a temporary client with a
                120 s timeout is used.

        Returns:
            One result per (input, returned vector) pair, in input order.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        if not texts:
            return []

        url = f"{self.base_url.rstrip('/')}/api/embed"
        payload = {
            "model": self.model,
            "input": texts,  # batch input
        }

        if client is not None:
            response = await client.post(url, json=payload)
        else:
            async with httpx.AsyncClient(timeout=120) as own_client:
                response = await own_client.post(url, json=payload)
        response.raise_for_status()
        data = response.json()

        # Response shape: {"embeddings": [[...], [...], ...]}
        embeddings = data.get("embeddings", [])
        # zip keeps each token estimate paired with its own text. The estimate
        # is a rough ~4-characters-per-token heuristic, not a real count.
        return [
            EmbeddingResult(
                embedding=embedding,
                tokens_used=len(text) // 4,
                model=self.model,
            )
            for text, embedding in zip(texts, embeddings)
        ]
class CohereEmbedding(EmbeddingProvider):
    """
    Cohere embeddings via the v2 API.

    - Endpoint: https://api.cohere.com/v2/embed
    - Plain strings go in 'texts'. The v2 'inputs' field expects structured
      content objects (text/image parts), not bare strings.
    - 'input_type' and 'embedding_types' are supplied explicitly.
    """

    MODELS = {
        "embed-english-v3.0": 1024,
        "embed-multilingual-v3.0": 1024,
        "embed-english-light-v3.0": 384,
        "embed-multilingual-light-v3.0": 384,
        # embed-v4.0 returns 1536-dim vectors unless output_dimension is set.
        "embed-v4.0": 1536,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "embed-multilingual-v3.0",
    ):
        """
        Args:
            api_key: Cohere API key (sent as a Bearer token).
            base_url: API base (default https://api.cohere.com/v2).
            model: Cohere embedding model name.
        """
        self.api_key = api_key
        self.base_url = base_url or "https://api.cohere.com/v2"
        self.model = model
        self._dimension = self.MODELS.get(model, 1024)

    @property
    def dimension(self) -> int:
        """Vector length reported for this model."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text through embed_texts."""
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` as search documents in one v2 /embed call.

        Args:
            texts: Inputs, sent unmodified.
            client: Shared client. When it is None, a temporary client with a
                60 s timeout is used.

        Returns:
            One result per returned float vector, in response order.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "texts": texts,  # plain strings belong in 'texts', not 'inputs'
            "input_type": "search_document",
            "embedding_types": ["float"],
        }
        url = f"{self.base_url.rstrip('/')}/embed"

        if client is not None:
            response = await client.post(url, headers=headers, json=payload)
        else:
            async with httpx.AsyncClient(timeout=60) as own_client:
                response = await own_client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()

        # Response shape: {"embeddings": {"float": [[...], ...]}, "meta": {...}}.
        # A bare list is also accepted.
        embeddings_data = data.get("embeddings", {})
        embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data

        billed_units = (data.get("meta") or {}).get("billed_units") or {}
        per_text_tokens = billed_units.get("input_tokens", 0) // max(len(texts), 1)
        return [
            EmbeddingResult(
                embedding=embedding,
                tokens_used=per_text_tokens,
                model=self.model,
            )
            for embedding in embeddings
        ]
class HuggingFaceEmbedding(EmbeddingProvider):
    """
    HuggingFace Inference Providers embeddings.

    Uses the Router endpoint:
    https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
    """

    MODELS = {
        "sentence-transformers/all-MiniLM-L6-v2": 384,
        "sentence-transformers/all-mpnet-base-v2": 768,
        "BAAI/bge-large-zh-v1.5": 1024,
        "BAAI/bge-m3": 1024,
        "BAAI/bge-small-en-v1.5": 384,
        "BAAI/bge-base-en-v1.5": 768,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "BAAI/bge-m3",
    ):
        """
        Args:
            api_key: HuggingFace token (sent as a Bearer token).
            base_url: Router base URL (default https://router.huggingface.co).
            model: Model repository id.
        """
        self.api_key = api_key
        self.base_url = base_url or "https://router.huggingface.co"
        self.model = model
        self._dimension = self.MODELS.get(model, 1024)

    @property
    def dimension(self) -> int:
        """Vector length reported for this model."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text through embed_texts."""
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` via the feature-extraction pipeline.

        Args:
            texts: Inputs, sent unmodified.
            client: Shared client. When it is None, a temporary client with a
                120 s timeout is used.

        Returns:
            One result per (input, returned item) pair, in input order.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            ValueError: If the response body is not a JSON list.
        """
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
        payload = {
            "inputs": texts,
            "options": {
                "wait_for_model": True,
            },
        }

        if client is not None:
            response = await client.post(url, headers=headers, json=payload)
        else:
            async with httpx.AsyncClient(timeout=120) as own_client:
                response = await own_client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()

        if not isinstance(data, list):
            raise ValueError(
                f"Unexpected HuggingFace feature-extraction response type: {type(data).__name__}"
            )
        # A deployment may answer with one flat vector instead of a list of
        # vectors. Normalise that case to the list-of-vectors form.
        if data and all(isinstance(value, (int, float)) for value in data):
            data = [data]

        results: List[EmbeddingResult] = []
        for text, embedding in zip(texts, data):
            # Token-level output ([tokens][dim]) keeps the first token's vector.
            # NOTE(review): that is CLS pooling; fine for BGE, confirm for
            # mean-pooled models.
            if isinstance(embedding, list) and embedding and isinstance(embedding[0], list):
                embedding = embedding[0]
            results.append(EmbeddingResult(
                embedding=embedding,
                tokens_used=len(text) // 4,  # rough ~4 chars/token estimate
                model=self.model,
            ))
        return results
class JinaEmbedding(EmbeddingProvider):
    """Jina AI embeddings via ``POST {base_url}/embeddings``."""

    MODELS = {
        "jina-embeddings-v2-base-code": 768,
        "jina-embeddings-v2-base-en": 768,
        "jina-embeddings-v2-base-zh": 768,
        "jina-embeddings-v2-small-en": 512,
        "jina-embeddings-v3": 1024,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "jina-embeddings-v2-base-code",
    ):
        """
        Args:
            api_key: Jina API key (sent as a Bearer token).
            base_url: API base (default https://api.jina.ai/v1).
            model: Jina embedding model name.
        """
        self.api_key = api_key
        self.base_url = base_url or "https://api.jina.ai/v1"
        self.model = model
        self._dimension = self.MODELS.get(model, 768)

    @property
    def dimension(self) -> int:
        """Vector length reported for this model."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text through embed_texts."""
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` in one request.

        Args:
            texts: Inputs, sent unmodified.
            client: Shared client. When it is None, a temporary client with a
                60 s timeout is used.

        Returns:
            One result per returned item, ordered by the item's ``index``.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": texts,
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"

        if client is not None:
            response = await client.post(url, headers=headers, json=payload)
        else:
            async with httpx.AsyncClient(timeout=60) as own_client:
                response = await own_client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()

        # Align results with inputs via "index" (stable when absent).
        items = sorted(data.get("data", []), key=lambda item: item.get("index", 0))
        usage = data.get("usage") or {}
        per_text_tokens = usage.get("total_tokens", 0) // len(texts)
        return [
            EmbeddingResult(
                embedding=item["embedding"],
                tokens_used=per_text_tokens,
                model=self.model,
            )
            for item in items
        ]
class QwenEmbedding(EmbeddingProvider):
    """Qwen embeddings through Alibaba Cloud DashScope's OpenAI-compatible embeddings API."""

    MODELS = {
        # DashScope Qwen embedding models and their default dimensions
        "text-embedding-v4": 1024,  # supported: 2048, 1536, 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v3": 1024,  # supported: 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v2": 1536,  # supported: 1536
        "Qwen3-Embedding-4B": 2560,
    }

    # Models whose endpoint accepts a `dimensions` field to pick the output size.
    DIMENSION_CONFIGURABLE_MODELS = ("text-embedding-v3", "text-embedding-v4")

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-v4",
    ):
        """
        Args:
            api_key: Explicit key. Falls back to EMBEDDING_API_KEY, then
                QWEN_API_KEY, then LLM_API_KEY from settings.
            base_url: Endpoint (default: DashScope compatible-mode v1).
            model: DashScope embedding model name.

        Raises:
            ValueError: If no API key can be resolved.
        """
        self.api_key = (
            api_key
            or getattr(settings, "EMBEDDING_API_KEY", None)
            or getattr(settings, "QWEN_API_KEY", None)
            or settings.LLM_API_KEY
        )
        if not self.api_key:
            raise ValueError(
                "Qwen embedding requires API key. "
                "Set EMBEDDING_API_KEY, QWEN_API_KEY or LLM_API_KEY environment variable."
            )
        self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        self.model = model
        # A positive EMBEDDING_DIMENSION wins over the model table; unknown
        # models default to 1024. `or 0` tolerates a None setting.
        self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0) or 0
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1024)
        )

    @property
    def dimension(self) -> int:
        """Vector length reported for this model/configuration."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text through embed_texts."""
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` in one request.

        Args:
            texts: Inputs. Each is cut to its first 8191 characters, the same
                cut the OpenAI provider applies.
            client: Shared client. When it is None, a temporary client with a
                60 s timeout is used.

        Returns:
            One result per returned item, ordered by the item's ``index``.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response (logged first).
            RuntimeError: Wrapping network errors and any other failure.
        """
        if not texts:
            return []

        max_length = 8191
        truncated_texts = [text[:max_length] for text in texts]

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload: Dict[str, Any] = {
            "model": self.model,
            "input": truncated_texts,
        }
        # Request the configured size explicitly. Otherwise the API returns
        # the model default (e.g. 1024) while self.dimension reports the
        # configured value, and downstream vector sizes would mismatch.
        if self._explicit_dimension > 0 and self.model in self.DIMENSION_CONFIGURABLE_MODELS:
            payload["dimensions"] = self._explicit_dimension
        url = f"{self.base_url.rstrip('/')}/embeddings"

        try:
            if client is not None:
                response = await client.post(url, headers=headers, json=payload)
            else:
                async with httpx.AsyncClient(timeout=60.0) as own_client:
                    response = await own_client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()

            usage = data.get("usage", {}) or {}
            total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
            per_text_tokens = total_tokens // max(len(texts), 1)
            # Align results with inputs via "index" (stable when absent).
            items = sorted(data.get("data", []), key=lambda item: item.get("index", 0))
            return [
                EmbeddingResult(
                    embedding=item["embedding"],
                    tokens_used=per_text_tokens,
                    model=self.model,
                )
                for item in items
            ]
        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}")
            raise
        except httpx.RequestError as e:
            logger.error(f"Qwen embedding network error: {e}")
            raise RuntimeError(f"Qwen embedding network error: {e}") from e
        except Exception as e:
            logger.error(f"Qwen embedding unexpected error: {e}")
            raise RuntimeError(f"Qwen embedding failed: {e}") from e
class EmbeddingService:
    """
    Embedding service.

    Owns one provider instance and adds an in-memory cache, batching,
    bounded concurrency and retries on top of it.

    Supported providers:
    - openai: OpenAI (also used for any unrecognised provider name)
    - azure: Azure OpenAI
    - ollama: local Ollama
    - cohere: Cohere
    - huggingface: HuggingFace Inference API
    - jina: Jina AI
    - qwen: Alibaba Cloud DashScope (Qwen)
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        cache_enabled: bool = True,
    ):
        """
        Initialise the embedding service.

        Args:
            provider: Provider name (openai, azure, ollama, cohere, huggingface,
                jina, qwen). Defaults to settings.EMBEDDING_PROVIDER, else "openai".
            model: Model name. Defaults to settings.EMBEDDING_MODEL, else
                "text-embedding-3-small".
            api_key: API key. Defaults to settings.EMBEDDING_API_KEY.
            base_url: API base URL. Defaults to settings.EMBEDDING_BASE_URL.
            cache_enabled: Whether to memoise vectors in memory.
        """
        self.cache_enabled = cache_enabled
        # Unbounded in-memory cache keyed by a sha256 prefix of the text.
        self._cache: Dict[str, List[float]] = {}

        # Keep the raw configuration values for attribute access.
        self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
        self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
        self.api_key = api_key or getattr(settings, "EMBEDDING_API_KEY", None)
        self.base_url = base_url or getattr(settings, "EMBEDDING_BASE_URL", None)

        self._provider = self._create_provider(
            provider=self.provider,
            model=self.model,
            api_key=self.api_key,
            base_url=self.base_url,
        )

        # Remote APIs are rate limited, so they get fewer concurrent requests
        # and smaller batches than local backends.
        is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
        default_concurrency = 2 if is_remote else 10
        # `or` also covers a setting of None (Semaphore would raise) or 0
        # (every request would block forever).
        self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", default_concurrency) or default_concurrency
        self._semaphore = asyncio.Semaphore(self.concurrency)

        # Default batch size for embed_batch. DashScope text-embedding-v4
        # accepts at most 10 inputs per request.
        self.batch_size = 10 if is_remote else 100

        # Shared HTTP client, created lazily by _get_client().
        self._client: Optional[httpx.AsyncClient] = None

        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})")

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared AsyncClient, (re)creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=120.0)
        return self._client

    async def close(self):
        """Close the shared client, if one is open."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
            self._client = None

    def _create_provider(
        self,
        provider: str,
        model: str,
        api_key: Optional[str],
        base_url: Optional[str],
    ) -> EmbeddingProvider:
        """Instantiate the provider class matching ``provider`` (case-insensitive).

        Unknown names fall back to OpenAIEmbedding.
        """
        name = provider.lower()
        if name == "ollama":
            # OllamaEmbedding takes no api_key.
            return OllamaEmbedding(base_url=base_url, model=model)
        keyed_providers = {
            "azure": AzureOpenAIEmbedding,
            "cohere": CohereEmbedding,
            "huggingface": HuggingFaceEmbedding,
            "jina": JinaEmbedding,
            "qwen": QwenEmbedding,
        }
        provider_cls = keyed_providers.get(name, OpenAIEmbedding)
        return provider_cls(api_key=api_key, base_url=base_url, model=model)

    @property
    def dimension(self) -> int:
        """Vector length reported by the active provider."""
        return self._provider.dimension

    def _cache_key(self, text: str) -> str:
        """Build the cache key: the first 32 hex chars of sha256(text)."""
        return hashlib.sha256(text.encode()).hexdigest()[:32]

    async def embed(self, text: str) -> List[float]:
        """
        Embed a single text.

        Args:
            text: Text content. Empty or whitespace-only text yields a zero vector.

        Returns:
            The embedding vector, served from the cache when available.
        """
        if not text or not text.strip():
            return [0.0] * self.dimension

        cache_key = self._cache_key(text) if self.cache_enabled else None
        if cache_key is not None and cache_key in self._cache:
            return self._cache[cache_key]

        result = await self._provider.embed_text(text)

        if cache_key is not None:
            self._cache[cache_key] = result.embedding
        return result.embedding

    async def embed_batch(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        show_progress: bool = False,
        progress_callback: Optional[callable] = None,
        cancel_check: Optional[callable] = None,
    ) -> List[List[float]]:
        """
        Embed many texts, reusing cached vectors and batching the rest.

        Args:
            texts: Texts to embed. The output order matches the input order.
            batch_size: Texts per provider request. None (the default) or a
                non-positive value uses self.batch_size, which is set per
                provider (10 for remote APIs, 100 for local ones).
            show_progress: Log progress at INFO level after each batch.
            progress_callback: Called as (processed, total) each time a batch
                finishes, whether it succeeded or not. ``total`` counts only
                texts that missed the cache.
            cancel_check: Polled before every request attempt. Returning True
                aborts the whole call.

        Returns:
            One vector per input. Blank texts, and texts whose batch still
            failed after all retries, get a zero vector of length self.dimension.

        Raises:
            asyncio.CancelledError: When cancel_check returns True, or when the
                call itself is cancelled. Outstanding batches are cancelled too.
        """
        if not texts:
            return []

        embeddings: List[Optional[List[float]]] = []
        uncached_indices: List[int] = []
        uncached_texts: List[str] = []

        # Serve blanks and cache hits immediately and queue the rest.
        for i, text in enumerate(texts):
            if not text or not text.strip():
                embeddings.append([0.0] * self.dimension)
                continue
            if self.cache_enabled:
                cache_key = self._cache_key(text)
                if cache_key in self._cache:
                    embeddings.append(self._cache[cache_key])
                    continue
            embeddings.append(None)  # placeholder, filled in by a batch
            uncached_indices.append(i)
            uncached_texts.append(text)

        if uncached_texts:
            # The old default of 100 always overrode the per-provider
            # self.batch_size, which DashScope rejects.
            step = batch_size if batch_size and batch_size > 0 else self.batch_size
            total = len(uncached_texts)
            processed = 0

            async def run_batch(batch: List[str], batch_indices: List[int]) -> None:
                """Embed one batch, store its vectors, and report progress."""
                nonlocal processed
                try:
                    results = await self._process_batch_with_retry(batch, batch_indices, cancel_check)
                except Exception as e:
                    # Best effort: vectors left as None become zero vectors below.
                    # CancelledError is a BaseException and propagates.
                    logger.error(f"Batch processing failed: {e}")
                    results = []
                for idx, result in zip(batch_indices, results):
                    embeddings[idx] = result.embedding
                    if self.cache_enabled:
                        self._cache[self._cache_key(texts[idx])] = result.embedding
                processed += len(batch_indices)
                if show_progress:
                    logger.info(f"Embedding progress: {processed}/{total}")
                if progress_callback:
                    try:
                        progress_callback(processed, total)
                    except Exception as e:
                        logger.warning(f"Progress callback error: {e}")

            tasks = [
                asyncio.ensure_future(
                    run_batch(uncached_texts[start:start + step], uncached_indices[start:start + step])
                )
                for start in range(0, total, step)
            ]
            try:
                await asyncio.gather(*tasks)
            except BaseException:
                # gather() does not cancel siblings when one child is cancelled
                # or raises. Stop them explicitly so no batch keeps running in
                # the background, then propagate.
                for task in tasks:
                    task.cancel()
                raise

        # Anything still unfilled (failed or short batches) becomes a zero vector.
        return [e if e is not None else [0.0] * self.dimension for e in embeddings]

    async def _process_batch_with_retry(
        self,
        batch: List[str],
        indices: List[int],
        cancel_check: Optional[callable] = None,
        max_retries: Optional[int] = None
    ) -> List[EmbeddingResult]:
        """
        Embed one batch, retrying transient failures.

        Args:
            batch: Texts for one provider request.
            indices: Positions of ``batch`` in the caller's list. Not used here.
            cancel_check: Polled before each attempt. True aborts.
            max_retries: Total attempts. Defaults to settings.EMBEDDING_RETRY_MAX
                (else 5); at least one attempt is always made.

        Retry policy:
            - HTTP 429: exponential back-off of 5 s, 10 s, 20 s, ...
            - HTTP 5xx and non-HTTP errors: fixed 2 s pause.
            - Other HTTP status errors: raised immediately.

        Raises:
            asyncio.CancelledError: When cancel_check returns True.
            Exception: The last error once attempts are exhausted.
        """
        configured = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
        # A 0/None setting used to skip the request entirely and return [].
        actual_max_retries = max(1, configured or 1)
        client = await self._get_client()

        for attempt in range(actual_max_retries):
            if cancel_check and cancel_check():
                raise asyncio.CancelledError("嵌入操作已取消")
            is_last_attempt = attempt >= actual_max_retries - 1

            # NOTE: back-off sleeps happen while the semaphore is held, so a
            # throttled batch also keeps its concurrency slot.
            async with self._semaphore:
                try:
                    return await self._provider.embed_texts(batch, client=client)
                except httpx.HTTPStatusError as e:
                    status = e.response.status_code
                    if is_last_attempt:
                        raise
                    if status == 429:
                        wait_time = (2 ** attempt) * 5
                        logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{actual_max_retries})")
                        await asyncio.sleep(wait_time)
                    elif status >= 500:
                        # Server-side failures are usually transient.
                        logger.warning(f"Server error ({status}), retrying in 2s... (Attempt {attempt+1}/{actual_max_retries})")
                        await asyncio.sleep(2)
                    else:
                        raise
                except Exception:
                    if is_last_attempt:
                        raise
                    await asyncio.sleep(2)

    def clear_cache(self):
        """Drop every cached vector."""
        self._cache.clear()

    @property
    def cache_size(self) -> int:
        """Number of cached vectors."""
        return len(self._cache)