""" 嵌入模型服务 支持多种嵌入模型提供商: OpenAI, Azure, Ollama, Cohere, HuggingFace, Jina """ import asyncio import hashlib import logging from typing import List, Dict, Any, Optional from abc import ABC, abstractmethod from dataclasses import dataclass import httpx from app.core.config import settings logger = logging.getLogger(__name__) @dataclass class EmbeddingResult: """嵌入结果""" embedding: List[float] tokens_used: int model: str class EmbeddingProvider(ABC): """嵌入提供商基类""" @abstractmethod async def embed_text(self, text: str) -> EmbeddingResult: """嵌入单个文本""" pass @abstractmethod async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: """批量嵌入文本""" pass @property @abstractmethod def dimension(self) -> int: """嵌入向量维度""" pass class OpenAIEmbedding(EmbeddingProvider): """OpenAI 嵌入服务""" MODELS = { "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, "text-embedding-ada-002": 1536, "Qwen3-Embedding-4B": 2560, } def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: str = "text-embedding-3-small", ): # 优先使用显性参数,其次使用 EMBEDDING_API_KEY,最后使用 LLM_API_KEY self.api_key = ( api_key or getattr(settings, "EMBEDDING_API_KEY", None) or settings.LLM_API_KEY ) # 优先使用显性参数,其次使用 EMBEDDING_BASE_URL,最后使用 OpenAI 默认地址 self.base_url = ( base_url or getattr(settings, "EMBEDDING_BASE_URL", None) or "https://api.openai.com/v1" ) self.model = model # 优先使用显式指定的维度,其次根据模型预定义,最后默认 1536 self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0) self._dimension = ( self._explicit_dimension if self._explicit_dimension > 0 else self.MODELS.get(model, 1536) ) @property def dimension(self) -> int: return self._dimension async def embed_text(self, text: str) -> EmbeddingResult: results = await self.embed_texts([text]) return results[0] async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: if not texts: return [] max_length = 8191 truncated_texts = [text[:max_length] for text in texts] headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } payload = { "model": self.model, "input": truncated_texts, } url = f"{self.base_url.rstrip('/')}/embeddings" async with httpx.AsyncClient(timeout=60) as client: response = await client.post(url, headers=headers, json=payload) response.raise_for_status() data = response.json() results = [] for item in data.get("data", []): results.append(EmbeddingResult( embedding=item["embedding"], tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts), model=self.model, )) return results class AzureOpenAIEmbedding(EmbeddingProvider): """ Azure OpenAI 嵌入服务 使用最新 API 版本 2024-10-21 (GA) 端点格式: https://.openai.azure.com/openai/deployments//embeddings?api-version=2024-10-21 """ MODELS = { "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, "text-embedding-ada-002": 1536, "Qwen3-Embedding-4B": 2560, } # 最新的 GA API 版本 API_VERSION = "2024-10-21" def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: str = "text-embedding-3-small", ): self.api_key = api_key self.base_url = base_url or "https://your-resource.openai.azure.com" self.model = model # 优先使用项目配置中的显式维度 self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0) self._dimension = ( self._explicit_dimension if self._explicit_dimension > 0 else self.MODELS.get(model, 1536) ) @property def dimension(self) -> int: return self._dimension async def embed_text(self, text: str) -> EmbeddingResult: results = await self.embed_texts([text]) return results[0] async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: if not texts: return [] max_length = 8191 truncated_texts = [text[:max_length] for text in texts] headers = { "api-key": self.api_key, "Content-Type": "application/json", } payload = { "input": truncated_texts, } # Azure URL 格式 - 使用最新 API 版本 url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}" async with httpx.AsyncClient(timeout=60) as client: response = await client.post(url, headers=headers, json=payload) response.raise_for_status() data = response.json() results = [] for item in data.get("data", []): results.append(EmbeddingResult( embedding=item["embedding"], tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts), model=self.model, )) return results class OllamaEmbedding(EmbeddingProvider): """ Ollama 本地嵌入服务 使用新的 /api/embed 端点 (2024年起): - 支持批量嵌入 - 使用 'input' 参数(支持字符串或字符串数组) """ MODELS = { "nomic-embed-text": 768, "mxbai-embed-large": 1024, "all-minilm": 384, "snowflake-arctic-embed": 1024, "bge-m3": 1024, "qwen3-embedding": 1024, } def __init__( self, base_url: Optional[str] = None, model: str = "nomic-embed-text", ): self.base_url = base_url or "http://localhost:11434" self.model = model self._dimension = self.MODELS.get(model, 768) @property def dimension(self) -> int: return self._dimension async def embed_text(self, text: str) -> EmbeddingResult: results = await self.embed_texts([text]) return results[0] async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: if not texts: return [] # 新的 Ollama /api/embed 端点 url = f"{self.base_url.rstrip('/')}/api/embed" payload = { "model": self.model, "input": texts, # 新 API 使用 'input' 参数,支持批量 } async with httpx.AsyncClient(timeout=120) as client: response = await client.post(url, json=payload) response.raise_for_status() data = response.json() # 新 API 返回格式: {"embeddings": [[...], [...], ...]} embeddings = data.get("embeddings", []) results = [] for i, embedding in enumerate(embeddings): results.append(EmbeddingResult( embedding=embedding, tokens_used=len(texts[i]) // 4, model=self.model, )) return results class CohereEmbedding(EmbeddingProvider): """ Cohere 嵌入服务 使用新的 v2 API (2024年起): - 端点: https://api.cohere.com/v2/embed - 使用 'inputs' 参数替代 'texts' - 需要指定 'embedding_types' """ MODELS = { "embed-english-v3.0": 1024, "embed-multilingual-v3.0": 1024, "embed-english-light-v3.0": 384, "embed-multilingual-light-v3.0": 384, "embed-v4.0": 1024, # 最新模型 } def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: str = "embed-multilingual-v3.0", ): self.api_key = api_key # 新的 v2 API 端点 self.base_url = base_url or "https://api.cohere.com/v2" self.model = model self._dimension = self.MODELS.get(model, 1024) @property def dimension(self) -> int: return self._dimension async def embed_text(self, text: str) -> EmbeddingResult: results = await self.embed_texts([text]) return results[0] async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: if not texts: return [] headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } # v2 API 参数格式 payload = { "model": self.model, "inputs": texts, # v2 使用 'inputs' 而非 'texts' "input_type": "search_document", "embedding_types": ["float"], # v2 需要指定嵌入类型 } url = f"{self.base_url.rstrip('/')}/embed" async with httpx.AsyncClient(timeout=60) as client: response = await client.post(url, headers=headers, json=payload) response.raise_for_status() data = response.json() results = [] # v2 API 返回格式: {"embeddings": {"float": [[...], [...]]}, ...} embeddings_data = data.get("embeddings", {}) embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data for embedding in embeddings: results.append(EmbeddingResult( embedding=embedding, tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1), model=self.model, )) return results class HuggingFaceEmbedding(EmbeddingProvider): """ HuggingFace Inference Providers 嵌入服务 使用新的 Router 端点 (2025年起): https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction """ MODELS = { "sentence-transformers/all-MiniLM-L6-v2": 384, "sentence-transformers/all-mpnet-base-v2": 768, "BAAI/bge-large-zh-v1.5": 1024, "BAAI/bge-m3": 1024, "BAAI/bge-small-en-v1.5": 384, "BAAI/bge-base-en-v1.5": 768, } def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: str = "BAAI/bge-m3", ): self.api_key = api_key # 新的 Router 端点 self.base_url = base_url or "https://router.huggingface.co" self.model = model self._dimension = self.MODELS.get(model, 1024) @property def dimension(self) -> int: return self._dimension async def embed_text(self, text: str) -> EmbeddingResult: results = await self.embed_texts([text]) return results[0] async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: if not texts: return [] headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } # 新的 HuggingFace Router URL 格式 # https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction" payload = { "inputs": texts, "options": { "wait_for_model": True, } } async with httpx.AsyncClient(timeout=120) as client: response = await client.post(url, headers=headers, json=payload) response.raise_for_status() data = response.json() results = [] # HuggingFace 返回格式: [[embedding1], [embedding2], ...] for embedding in data: # 有时候返回的是嵌套的列表 if isinstance(embedding, list) and len(embedding) > 0: if isinstance(embedding[0], list): # 取平均或第一个 embedding = embedding[0] results.append(EmbeddingResult( embedding=embedding, tokens_used=len(texts[len(results)]) // 4, model=self.model, )) return results class JinaEmbedding(EmbeddingProvider): """Jina AI 嵌入服务""" MODELS = { "jina-embeddings-v2-base-code": 768, "jina-embeddings-v2-base-en": 768, "jina-embeddings-v2-base-zh": 768, "jina-embeddings-v2-small-en": 512, } def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: str = "jina-embeddings-v2-base-code", ): self.api_key = api_key self.base_url = base_url or "https://api.jina.ai/v1" self.model = model self._dimension = self.MODELS.get(model, 768) @property def dimension(self) -> int: return self._dimension async def embed_text(self, text: str) -> EmbeddingResult: results = await self.embed_texts([text]) return results[0] async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: if not texts: return [] headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } payload = { "model": self.model, "input": texts, } url = f"{self.base_url.rstrip('/')}/embeddings" async with httpx.AsyncClient(timeout=60) as client: response = await client.post(url, headers=headers, json=payload) response.raise_for_status() data = response.json() results = [] for item in data.get("data", []): results.append(EmbeddingResult( embedding=item["embedding"], tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts), model=self.model, )) return results class QwenEmbedding(EmbeddingProvider): """Qwen 嵌入服务(基于阿里云 DashScope embeddings API)""" MODELS = { # DashScope Qwen 嵌入模型及其默认维度 "text-embedding-v4": 1024, # 支持维度: 2048, 1536, 1024(默认), 768, 512, 256, 128, 64 "text-embedding-v3": 1024, # 支持维度: 1024(默认), 768, 512, 256, 128, 64 "text-embedding-v2": 1536, # 支持维度: 1536 "Qwen3-Embedding-4B": 2560, } def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: str = "text-embedding-v4", ): # 优先使用显式传入的 api_key,其次使用 EMBEDDING_API_KEY/QWEN_API_KEY/LLM_API_KEY self.api_key = ( api_key or getattr(settings, "EMBEDDING_API_KEY", None) or getattr(settings, "QWEN_API_KEY", None) or settings.LLM_API_KEY ) # 🔥 API 密钥验证 if not self.api_key: raise ValueError( "Qwen embedding requires API key. " "Set EMBEDDING_API_KEY, QWEN_API_KEY or LLM_API_KEY environment variable." ) # DashScope 兼容 OpenAI 的 embeddings 端点 self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1" self.model = model # 优先使用显式指定的维度,其次根据模型预定义,最后默认 1024 self._explicit_dimension = getattr(settings, "EMBEDDING_DIMENSION", 0) self._dimension = ( self._explicit_dimension if self._explicit_dimension > 0 else self.MODELS.get(model, 1024) ) @property def dimension(self) -> int: return self._dimension async def embed_text(self, text: str) -> EmbeddingResult: results = await self.embed_texts([text]) return results[0] async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: if not texts: return [] # 与 OpenAI 接口保持一致的截断策略 max_length = 8191 truncated_texts = [text[:max_length] for text in texts] headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } payload = { "model": self.model, "input": truncated_texts, } url = f"{self.base_url.rstrip('/')}/embeddings" try: async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post( url, headers=headers, json=payload, ) response.raise_for_status() data = response.json() usage = data.get("usage", {}) or {} total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0 results: List[EmbeddingResult] = [] for item in data.get("data", []): results.append(EmbeddingResult( embedding=item["embedding"], tokens_used=total_tokens // max(len(texts), 1), model=self.model, )) return results except httpx.HTTPStatusError as e: logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}") raise RuntimeError(f"Qwen embedding API failed: {e.response.status_code}") from e except httpx.RequestError as e: logger.error(f"Qwen embedding network error: {e}") raise RuntimeError(f"Qwen embedding network error: {e}") from e except Exception as e: logger.error(f"Qwen embedding unexpected error: {e}") raise RuntimeError(f"Qwen embedding failed: {e}") from e class EmbeddingService: """ 嵌入服务 统一管理嵌入模型和缓存 支持的提供商: - openai: OpenAI 官方 - azure: Azure OpenAI - ollama: Ollama 本地 - cohere: Cohere - huggingface: HuggingFace Inference API - jina: Jina AI """ def __init__( self, provider: Optional[str] = None, model: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None, cache_enabled: bool = True, ): """ 初始化嵌入服务 Args: provider: 提供商 (openai, azure, ollama, cohere, huggingface, jina) model: 模型名称 api_key: API Key base_url: API Base URL cache_enabled: 是否启用缓存 """ self.cache_enabled = cache_enabled self._cache: Dict[str, List[float]] = {} # 确定提供商(保存原始值用于属性访问) self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai') self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') self.api_key = api_key or getattr(settings, "EMBEDDING_API_KEY", None) self.base_url = base_url or getattr(settings, "EMBEDDING_BASE_URL", None) # 创建提供商实例 self._provider = self._create_provider( provider=self.provider, model=self.model, api_key=self.api_key, base_url=self.base_url, ) # 🔥 控制并发请求数 (RPS 限制) # 全局 RPS 限制为 30,由 4 个 gunicorn worker 共享 # 每个 worker 限制为 30/4 = 7 个并发请求,确保不触发限流 self._semaphore = asyncio.Semaphore(7) # 🔥 设置默认批次大小 (对于 remote 模型,用户要求为 10) is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"] self.batch_size = 10 if is_remote else 100 logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Batch size: {self.batch_size})") def _create_provider( self, provider: str, model: str, api_key: Optional[str], base_url: Optional[str], ) -> EmbeddingProvider: """创建嵌入提供商实例""" provider = provider.lower() if provider == "ollama": return OllamaEmbedding(base_url=base_url, model=model) elif provider == "azure": return AzureOpenAIEmbedding(api_key=api_key, base_url=base_url, model=model) elif provider == "cohere": return CohereEmbedding(api_key=api_key, base_url=base_url, model=model) elif provider == "huggingface": return HuggingFaceEmbedding(api_key=api_key, base_url=base_url, model=model) elif provider == "jina": return JinaEmbedding(api_key=api_key, base_url=base_url, model=model) elif provider == "qwen": return QwenEmbedding(api_key=api_key, base_url=base_url, model=model) else: # 默认使用 OpenAI return OpenAIEmbedding(api_key=api_key, base_url=base_url, model=model) @property def dimension(self) -> int: """嵌入向量维度""" return self._provider.dimension def _cache_key(self, text: str) -> str: """生成缓存键""" return hashlib.sha256(text.encode()).hexdigest()[:32] async def embed(self, text: str) -> List[float]: """ 嵌入单个文本 Args: text: 文本内容 Returns: 嵌入向量 """ if not text or not text.strip(): return [0.0] * self.dimension # 检查缓存 if self.cache_enabled: cache_key = self._cache_key(text) if cache_key in self._cache: return self._cache[cache_key] result = await self._provider.embed_text(text) # 存入缓存 if self.cache_enabled: self._cache[cache_key] = result.embedding return result.embedding async def embed_batch( self, texts: List[str], batch_size: int = 100, show_progress: bool = False, progress_callback: Optional[callable] = None, cancel_check: Optional[callable] = None, ) -> List[List[float]]: """ 批量嵌入文本 Args: texts: 文本列表 batch_size: 批次大小 show_progress: 是否显示进度 progress_callback: 进度回调函数,接收 (processed, total) 参数 cancel_check: 取消检查函数,返回 True 表示应该取消 Returns: 嵌入向量列表 Raises: asyncio.CancelledError: 当 cancel_check 返回 True 时 """ if not texts: return [] embeddings = [] uncached_indices = [] uncached_texts = [] # 检查缓存 for i, text in enumerate(texts): if not text or not text.strip(): embeddings.append([0.0] * self.dimension) continue if self.cache_enabled: cache_key = self._cache_key(text) if cache_key in self._cache: embeddings.append(self._cache[cache_key]) continue embeddings.append(None) # 占位 uncached_indices.append(i) uncached_texts.append(text) # 批量处理未缓存的文本 if uncached_texts: tasks = [] current_batch_size = batch_size or self.batch_size for i in range(0, len(uncached_texts), current_batch_size): batch = uncached_texts[i:i + current_batch_size] batch_indices = uncached_indices[i:i + current_batch_size] tasks.append(self._process_batch_with_retry(batch, batch_indices, cancel_check)) # 🔥 并发执行所有批次任务 all_batch_results = await asyncio.gather(*tasks, return_exceptions=True) for i, result_list in enumerate(all_batch_results): batch_indices = uncached_indices[i * current_batch_size : (i + 1) * current_batch_size] if isinstance(result_list, Exception): logger.error(f"Batch processing failed: {result_list}") # 失败批次使用零向量 for idx in batch_indices: if embeddings[idx] is None: embeddings[idx] = [0.0] * self.dimension continue for idx, result in zip(batch_indices, result_list): embeddings[idx] = result.embedding # 存入缓存 if self.cache_enabled: cache_key = self._cache_key(texts[idx]) self._cache[cache_key] = result.embedding # 🔥 调用进度回调 if progress_callback: processed_count = min((i + 1) * current_batch_size, len(uncached_texts)) try: progress_callback(processed_count, len(uncached_texts)) except Exception as e: logger.warning(f"Progress callback error: {e}") # 确保没有 None return [e if e is not None else [0.0] * self.dimension for e in embeddings] async def _process_batch_with_retry( self, batch: List[str], indices: List[int], cancel_check: Optional[callable] = None, max_retries: int = 3 ) -> List[EmbeddingResult]: """带重试机制的单批次处理""" for attempt in range(max_retries): if cancel_check and cancel_check(): raise asyncio.CancelledError("嵌入操作已取消") async with self._semaphore: try: return await self._provider.embed_texts(batch) except httpx.HTTPStatusError as e: if e.response.status_code == 429 and attempt < max_retries - 1: # 429 限流,指数级退避 wait_time = (2 ** attempt) + 1 logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{max_retries})") await asyncio.sleep(wait_time) continue raise except Exception as e: if attempt < max_retries - 1: await asyncio.sleep(1) continue raise return [] def clear_cache(self): """清空缓存""" self._cache.clear() @property def cache_size(self) -> int: """缓存大小""" return len(self._cache)