"""
Embedding model service.

Supports multiple embedding providers: OpenAI (and OpenAI-compatible
endpoints), Azure OpenAI, Ollama, Cohere, HuggingFace, Jina and Qwen
(Alibaba Cloud DashScope).
"""

import asyncio
import hashlib
import logging
from typing import List, Dict, Any, Optional, Callable
from abc import ABC, abstractmethod
from dataclasses import dataclass

import httpx

from app.core.config import settings

logger = logging.getLogger(__name__)


# Inputs are cut to this many *characters* (not tokens) before being sent to
# OpenAI-style endpoints; it is only a coarse guard against over-long inputs.
_MAX_INPUT_CHARS = 8191


def _explicit_dimension_setting() -> int:
    """Return ``settings.EMBEDDING_DIMENSION``, treating a missing/None/0 value as 0."""
    return getattr(settings, "EMBEDDING_DIMENSION", 0) or 0


async def _post_json(
    url: str,
    payload: Dict[str, Any],
    headers: Optional[Dict[str, str]],
    client: Optional[httpx.AsyncClient],
    timeout: float,
) -> Any:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON body.

    Args:
        url: Endpoint URL.
        payload: JSON-serialisable request body.
        headers: Request headers, or None for none.
        client: Shared client to use; its lifetime is owned by the caller.
            When None, a short-lived client with ``timeout`` is opened and
            closed around this single request.
        timeout: Timeout for the short-lived client (ignored when ``client``
            is given).

    Raises:
        httpx.HTTPStatusError: The server answered with a non-2xx status.
        httpx.RequestError: Transport-level failure.
    """
    if client is not None:
        response = await client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()
    async with httpx.AsyncClient(timeout=timeout) as own_client:
        response = await own_client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()


@dataclass
class EmbeddingResult:
    """A single embedding returned by a provider."""

    # The embedding vector.
    embedding: List[float]
    # Tokens attributed to this input (often an even split or an estimate).
    tokens_used: int
    # Model name that produced the vector.
    model: str


def _parse_openai_style_response(
    data: Dict[str, Any], n_texts: int, model: str
) -> List[EmbeddingResult]:
    """Convert an OpenAI-format ``/embeddings`` response into results.

    Items are sorted by their ``index`` field (stable, so items without an
    index keep response order) so results line up with the request inputs.
    Total token usage (``total_tokens``, else ``prompt_tokens``) is split
    evenly across the ``n_texts`` inputs.
    """
    usage = data.get("usage") or {}
    total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
    per_text = total_tokens // max(n_texts, 1)
    items = sorted(data.get("data", []), key=lambda item: item.get("index") or 0)
    return [
        EmbeddingResult(embedding=item["embedding"], tokens_used=per_text, model=model)
        for item in items
    ]


class EmbeddingProvider(ABC):
    """Base class for embedding providers."""

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text by delegating to :meth:`embed_texts`.

        Raises:
            IndexError: The provider returned no embedding for the text.
        """
        results = await self.embed_texts([text])
        if not results:
            raise IndexError(f"{type(self).__name__} returned no embedding")
        return results[0]

    @abstractmethod
    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed a batch of texts, optionally reusing a shared HTTP client."""
        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of the embedding vectors this provider is expected to return."""
        pass


class OpenAIEmbedding(EmbeddingProvider):
    """OpenAI embeddings (also used for OpenAI-compatible ``/embeddings`` endpoints)."""

    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        "Qwen3-Embedding-4B": 2560,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-small",
    ):
        # Explicit argument first, then EMBEDDING_API_KEY, then LLM_API_KEY.
        self.api_key = (
            api_key
            or getattr(settings, "EMBEDDING_API_KEY", None)
            or settings.LLM_API_KEY
        )
        # Explicit argument first, then EMBEDDING_BASE_URL, then the public OpenAI endpoint.
        self.base_url = (
            base_url
            or getattr(settings, "EMBEDDING_BASE_URL", None)
            or "https://api.openai.com/v1"
        )
        self.model = model
        # EMBEDDING_DIMENSION (when > 0) wins; otherwise the model table; otherwise 1536.
        self._explicit_dimension = _explicit_dimension_setting()
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1536)
        )

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` via ``POST {base_url}/embeddings`` (inputs truncated by character count)."""
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": [text[:_MAX_INPUT_CHARS] for text in texts],
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"
        data = await _post_json(url, payload, headers, client, timeout=60)
        return _parse_openai_style_response(data, len(texts), self.model)


class AzureOpenAIEmbedding(EmbeddingProvider):
    """
    Azure OpenAI embeddings.

    Uses the GA API version 2024-10-21. Endpoint format:
    https://<resource>.openai.azure.com/openai/deployments/<deployment>/embeddings?api-version=2024-10-21
    (``model`` is used as the deployment name).
    """

    MODELS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        "Qwen3-Embedding-4B": 2560,
    }

    # GA API version appended to every request URL.
    API_VERSION = "2024-10-21"

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-3-small",
    ):
        self.api_key = api_key
        self.base_url = base_url or "https://your-resource.openai.azure.com"
        self.model = model
        # EMBEDDING_DIMENSION (when > 0) wins; otherwise the model table; otherwise 1536.
        self._explicit_dimension = _explicit_dimension_setting()
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1536)
        )

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` via the Azure deployment endpoint (inputs truncated by character count)."""
        if not texts:
            return []

        headers = {
            "api-key": self.api_key,
            "Content-Type": "application/json",
        }
        payload = {
            "input": [text[:_MAX_INPUT_CHARS] for text in texts],
        }
        url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}"
        data = await _post_json(url, payload, headers, client, timeout=60)
        return _parse_openai_style_response(data, len(texts), self.model)


class OllamaEmbedding(EmbeddingProvider):
    """
    Local Ollama embeddings.

    Uses the batch ``/api/embed`` endpoint, whose ``input`` field accepts a
    string or a list of strings and whose response is
    ``{"embeddings": [[...], [...], ...]}``.
    """

    MODELS = {
        "nomic-embed-text": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
        "bge-m3": 1024,
        "qwen3-embedding": 1024,
    }

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: str = "nomic-embed-text",
    ):
        self.base_url = base_url or "http://localhost:11434"
        self.model = model
        self._dimension = self.MODELS.get(model, 768)

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` in one ``/api/embed`` call; token usage is estimated as len(text) // 4."""
        if not texts:
            return []

        url = f"{self.base_url.rstrip('/')}/api/embed"
        payload = {
            "model": self.model,
            "input": texts,
        }
        data = await _post_json(url, payload, None, client, timeout=120)

        embeddings = data.get("embeddings", [])
        # zip() guards against the server returning more vectors than inputs.
        return [
            EmbeddingResult(embedding=embedding, tokens_used=len(text) // 4, model=self.model)
            for text, embedding in zip(texts, embeddings)
        ]


class CohereEmbedding(EmbeddingProvider):
    """
    Cohere embeddings via the v2 API.

    Endpoint ``{base_url}/embed`` (default https://api.cohere.com/v2); the
    request uses ``inputs`` (not ``texts``) and must list ``embedding_types``.
    """

    MODELS = {
        "embed-english-v3.0": 1024,
        "embed-multilingual-v3.0": 1024,
        "embed-english-light-v3.0": 384,
        "embed-multilingual-light-v3.0": 384,
        # NOTE(review): confirm embed-v4.0's default output dimension against Cohere docs.
        "embed-v4.0": 1024,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "embed-multilingual-v3.0",
    ):
        self.api_key = api_key
        self.base_url = base_url or "https://api.cohere.com/v2"
        self.model = model
        self._dimension = self.MODELS.get(model, 1024)

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` as ``search_document`` inputs, requesting float embeddings."""
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "inputs": texts,
            "input_type": "search_document",
            "embedding_types": ["float"],
        }
        url = f"{self.base_url.rstrip('/')}/embed"
        data = await _post_json(url, payload, headers, client, timeout=60)

        # v2 returns {"embeddings": {"float": [[...], ...]}, ...}; a bare list is also accepted.
        embeddings_data = data.get("embeddings", {})
        embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data
        billed_units = (data.get("meta") or {}).get("billed_units") or {}
        per_text = (billed_units.get("input_tokens") or 0) // max(len(texts), 1)
        return [
            EmbeddingResult(embedding=embedding, tokens_used=per_text, model=self.model)
            for embedding in embeddings
        ]


class HuggingFaceEmbedding(EmbeddingProvider):
    """
    HuggingFace Inference Providers embeddings via the router endpoint:
    https://router.huggingface.co/hf-inference/models/{model}/pipeline/feature-extraction
    """

    MODELS = {
        "sentence-transformers/all-MiniLM-L6-v2": 384,
        "sentence-transformers/all-mpnet-base-v2": 768,
        "BAAI/bge-large-zh-v1.5": 1024,
        "BAAI/bge-m3": 1024,
        "BAAI/bge-small-en-v1.5": 384,
        "BAAI/bge-base-en-v1.5": 768,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "BAAI/bge-m3",
    ):
        self.api_key = api_key
        self.base_url = base_url or "https://router.huggingface.co"
        self.model = model
        self._dimension = self.MODELS.get(model, 1024)

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` with the feature-extraction pipeline.

        Token usage is estimated as len(text) // 4.

        Raises:
            ValueError: The response is not a list of non-empty vectors, which
                would otherwise misalign results with their inputs.
        """
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        url = f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction"
        payload = {
            "inputs": texts,
            "options": {
                "wait_for_model": True,
            },
        }
        data = await _post_json(url, payload, headers, client, timeout=120)

        if not isinstance(data, list):
            raise ValueError(f"Unexpected HuggingFace embedding response type: {type(data).__name__}")
        if len(texts) == 1 and data and isinstance(data[0], (int, float)):
            # A single flat vector for a single input.
            data = [data]

        results: List[EmbeddingResult] = []
        for text, embedding in zip(texts, data):
            if not isinstance(embedding, list) or not embedding:
                raise ValueError("Unexpected HuggingFace embedding item: expected a non-empty list")
            if isinstance(embedding[0], list):
                # Nested output (one vector per row): keep the first row's vector.
                embedding = embedding[0]
            results.append(EmbeddingResult(
                embedding=embedding,
                tokens_used=len(text) // 4,
                model=self.model,
            ))
        return results


class JinaEmbedding(EmbeddingProvider):
    """Jina AI embeddings (OpenAI-format ``/embeddings`` endpoint)."""

    MODELS = {
        "jina-embeddings-v2-base-code": 768,
        "jina-embeddings-v2-base-en": 768,
        "jina-embeddings-v2-base-zh": 768,
        "jina-embeddings-v2-small-en": 512,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "jina-embeddings-v2-base-code",
    ):
        self.api_key = api_key
        self.base_url = base_url or "https://api.jina.ai/v1"
        self.model = model
        self._dimension = self.MODELS.get(model, 768)

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` via ``POST {base_url}/embeddings`` (inputs are not truncated)."""
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": texts,
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"
        data = await _post_json(url, payload, headers, client, timeout=60)
        return _parse_openai_style_response(data, len(texts), self.model)


class QwenEmbedding(EmbeddingProvider):
    """Qwen embeddings via Alibaba Cloud DashScope's OpenAI-compatible embeddings API."""

    MODELS = {
        # DashScope Qwen embedding models and their default dimensions.
        "text-embedding-v4": 1024,  # supported: 2048, 1536, 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v3": 1024,  # supported: 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v2": 1536,  # supported: 1536
        "Qwen3-Embedding-4B": 2560,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-v4",
    ):
        """
        Raises:
            ValueError: No API key is available from the argument or settings.
        """
        # Explicit argument first, then EMBEDDING_API_KEY, QWEN_API_KEY, LLM_API_KEY.
        self.api_key = (
            api_key
            or getattr(settings, "EMBEDDING_API_KEY", None)
            or getattr(settings, "QWEN_API_KEY", None)
            or settings.LLM_API_KEY
        )
        if not self.api_key:
            raise ValueError(
                "Qwen embedding requires API key. "
                "Set EMBEDDING_API_KEY, QWEN_API_KEY or LLM_API_KEY environment variable."
            )

        # DashScope's OpenAI-compatible embeddings endpoint.
        self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        self.model = model

        # EMBEDDING_DIMENSION (when > 0) wins; otherwise the model table; otherwise 1024.
        self._explicit_dimension = _explicit_dimension_setting()
        self._dimension = (
            self._explicit_dimension
            if self._explicit_dimension > 0
            else self.MODELS.get(model, 1024)
        )

    @property
    def dimension(self) -> int:
        return self._dimension

    async def embed_texts(self, texts: List[str], client: Optional[httpx.AsyncClient] = None) -> List[EmbeddingResult]:
        """Embed ``texts`` (truncated like the OpenAI provider) and log failures.

        Raises:
            httpx.HTTPStatusError: Non-2xx response (logged, re-raised unchanged).
            RuntimeError: Network errors and any other failure, chained to the cause.
        """
        if not texts:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": [text[:_MAX_INPUT_CHARS] for text in texts],
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"

        try:
            data = await _post_json(url, payload, headers, client, timeout=60.0)
            return _parse_openai_style_response(data, len(texts), self.model)
        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen embedding API error: {e.response.status_code} - {e.response.text}")
            raise
        except httpx.RequestError as e:
            logger.error(f"Qwen embedding network error: {e}")
            raise RuntimeError(f"Qwen embedding network error: {e}") from e
        except Exception as e:
            logger.error(f"Qwen embedding unexpected error: {e}")
            raise RuntimeError(f"Qwen embedding failed: {e}") from e


class EmbeddingService:
    """
    Embedding service: wraps one provider and adds an in-memory cache,
    batching, bounded concurrency and retries.

    Supported providers:
    - openai: OpenAI (also the fallback for unknown provider names)
    - azure: Azure OpenAI
    - ollama: local Ollama
    - cohere: Cohere
    - huggingface: HuggingFace Inference API
    - jina: Jina AI
    - qwen: Alibaba Cloud DashScope
    """

    # Providers reached over the network: lower concurrency and smaller batches.
    _REMOTE_PROVIDERS = ("openai", "qwen", "azure", "cohere", "jina", "huggingface")

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        cache_enabled: bool = True,
    ):
        """
        Initialise the service.

        Args:
            provider: Provider name (openai, azure, ollama, cohere, huggingface,
                jina, qwen); defaults to settings.EMBEDDING_PROVIDER or "openai".
            model: Model name; defaults to settings.EMBEDDING_MODEL.
            api_key: API key; defaults to settings.EMBEDDING_API_KEY.
            base_url: API base URL; defaults to settings.EMBEDDING_BASE_URL.
            cache_enabled: Whether to cache embeddings in memory.
        """
        self.cache_enabled = cache_enabled
        # Unbounded cache keyed by a truncated SHA-256 of the text.
        self._cache: Dict[str, List[float]] = {}

        # Resolved configuration (kept on the instance for attribute access).
        self.provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
        self.model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
        self.api_key = api_key or getattr(settings, "EMBEDDING_API_KEY", None)
        self.base_url = base_url or getattr(settings, "EMBEDDING_BASE_URL", None)

        self._provider = self._create_provider(
            provider=self.provider,
            model=self.model,
            api_key=self.api_key,
            base_url=self.base_url,
        )

        is_remote = self.provider.lower() in self._REMOTE_PROVIDERS
        # Max in-flight requests. A configured 0/None would deadlock or break
        # the semaphore, so fall back to the default in that case.
        self.concurrency = getattr(settings, "EMBEDDING_CONCURRENCY", None) or (2 if is_remote else 10)
        self._semaphore = asyncio.Semaphore(self.concurrency)

        # Default batch size (DashScope text-embedding-v4 accepts at most 10 inputs).
        self.batch_size = 10 if is_remote else 100

        # Shared HTTP client, created lazily by _get_client().
        self._client: Optional[httpx.AsyncClient] = None

        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Concurrency: {self.concurrency}, Batch size: {self.batch_size})")

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared AsyncClient, (re)creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=120.0)
        return self._client

    async def close(self):
        """Close the shared client, if open."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    def _create_provider(
        self,
        provider: str,
        model: str,
        api_key: Optional[str],
        base_url: Optional[str],
    ) -> EmbeddingProvider:
        """Instantiate the provider for ``provider`` (case-insensitive); unknown names use OpenAI."""
        name = provider.lower()
        if name == "ollama":
            # Ollama takes no API key.
            return OllamaEmbedding(base_url=base_url, model=model)
        provider_classes = {
            "azure": AzureOpenAIEmbedding,
            "cohere": CohereEmbedding,
            "huggingface": HuggingFaceEmbedding,
            "jina": JinaEmbedding,
            "qwen": QwenEmbedding,
        }
        provider_cls = provider_classes.get(name, OpenAIEmbedding)
        return provider_cls(api_key=api_key, base_url=base_url, model=model)

    @property
    def dimension(self) -> int:
        """Embedding vector length reported by the underlying provider."""
        return self._provider.dimension

    def _cache_key(self, text: str) -> str:
        """Cache key: first 32 hex chars of the SHA-256 of the UTF-8 text."""
        return hashlib.sha256(text.encode()).hexdigest()[:32]

    async def embed(self, text: str) -> List[float]:
        """
        Embed a single text.

        Args:
            text: Text to embed.

        Returns:
            The embedding vector; a zero vector for empty/whitespace-only text.
        """
        if not text or not text.strip():
            return [0.0] * self.dimension

        if self.cache_enabled:
            cache_key = self._cache_key(text)
            if cache_key in self._cache:
                return self._cache[cache_key]

        result = await self._provider.embed_text(text)

        if self.cache_enabled:
            self._cache[cache_key] = result.embedding

        return result.embedding

    async def embed_batch(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        show_progress: bool = False,
        progress_callback: Optional[Callable[[int, int], Any]] = None,
        cancel_check: Optional[Callable[[], bool]] = None,
    ) -> List[List[float]]:
        """
        Embed many texts, using the cache and concurrent batched requests.

        Args:
            texts: Texts to embed.
            batch_size: Texts per request; defaults to ``self.batch_size``
                (10 for remote providers, 100 otherwise).
            show_progress: Currently unused; kept for API compatibility.
            progress_callback: Called as ``(processed, total)`` each time a
                batch finishes (successfully or not), where ``total`` counts
                only texts that were not blank and not cached.
            cancel_check: Polled before each request attempt; returning True
                cancels the operation.

        Returns:
            One vector per input, in input order. Blank texts and texts whose
            batch failed after all retries get zero vectors.

        Raises:
            asyncio.CancelledError: When ``cancel_check`` returns True.
        """
        if not texts:
            return []

        embeddings: List[Optional[List[float]]] = []
        uncached_indices: List[int] = []
        uncached_texts: List[str] = []

        # Blank texts -> zero vectors; cached texts -> cached vectors; others queued.
        for i, text in enumerate(texts):
            if not text or not text.strip():
                embeddings.append([0.0] * self.dimension)
                continue
            if self.cache_enabled:
                cached = self._cache.get(self._cache_key(text))
                if cached is not None:
                    embeddings.append(cached)
                    continue
            embeddings.append(None)  # placeholder, filled in below
            uncached_indices.append(i)
            uncached_texts.append(text)

        if uncached_texts:
            step = batch_size or self.batch_size
            total = len(uncached_texts)
            processed = 0

            async def run_batch(batch: List[str], batch_indices: List[int]) -> None:
                """Embed one batch, store results/cache entries, then report progress."""
                nonlocal processed
                try:
                    results = await self._process_batch_with_retry(batch, batch_indices, cancel_check)
                except Exception as e:
                    # Best effort: placeholders of a failed batch become zero vectors below.
                    # (CancelledError is not an Exception and propagates.)
                    logger.error(f"Batch processing failed: {e}")
                    results = []

                for idx, result in zip(batch_indices, results):
                    embeddings[idx] = result.embedding
                    if self.cache_enabled:
                        self._cache[self._cache_key(texts[idx])] = result.embedding

                processed += len(batch)
                if progress_callback:
                    try:
                        progress_callback(processed, total)
                    except Exception as e:
                        logger.warning(f"Progress callback error: {e}")

            # Run all batches concurrently; the semaphore bounds in-flight requests.
            outcomes = await asyncio.gather(
                *(
                    run_batch(uncached_texts[start:start + step], uncached_indices[start:start + step])
                    for start in range(0, total, step)
                ),
                return_exceptions=True,
            )
            # run_batch absorbs ordinary errors, so anything here is a
            # BaseException such as the CancelledError from cancel_check.
            for outcome in outcomes:
                if isinstance(outcome, BaseException):
                    raise outcome

        # Any remaining placeholder becomes a zero vector.
        return [e if e is not None else [0.0] * self.dimension for e in embeddings]

    async def _process_batch_with_retry(
        self,
        batch: List[str],
        indices: List[int],
        cancel_check: Optional[Callable[[], bool]] = None,
        max_retries: Optional[int] = None,
    ) -> List[EmbeddingResult]:
        """Embed one batch, retrying failures.

        Args:
            batch: Texts to embed.
            indices: Positions of ``batch`` in the caller's list (not used here).
            cancel_check: Polled before each attempt; True cancels.
            max_retries: Total attempts; defaults to settings.EMBEDDING_RETRY_MAX
                (5 if unset). At least one attempt is always made.

        Raises:
            asyncio.CancelledError: When ``cancel_check`` returns True.
            httpx.HTTPStatusError: Non-429 status, or 429 on the last attempt.
            Exception: Any other error from the final attempt.
        """
        configured = max_retries or getattr(settings, "EMBEDDING_RETRY_MAX", 5)
        actual_max_retries = max(int(configured or 1), 1)
        client = await self._get_client()

        for attempt in range(actual_max_retries):
            if cancel_check and cancel_check():
                raise asyncio.CancelledError("嵌入操作已取消")

            is_last_attempt = attempt == actual_max_retries - 1
            # Sleeps below happen while holding the semaphore, so a backing-off
            # batch also delays other batches waiting for a slot.
            async with self._semaphore:
                try:
                    return await self._provider.embed_texts(batch, client=client)
                except httpx.HTTPStatusError as e:
                    if e.response.status_code == 429 and not is_last_attempt:
                        # Rate limited: exponential backoff of 5s, 10s, 20s, ...
                        wait_time = (2 ** attempt) * 5
                        logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{actual_max_retries})")
                        await asyncio.sleep(wait_time)
                        continue
                    raise
                except Exception:
                    if not is_last_attempt:
                        # Other errors: fixed 2s pause before retrying.
                        await asyncio.sleep(2)
                        continue
                    raise

        # Not reached: the final attempt always returns or raises.
        return []

    def clear_cache(self):
        """Remove all cached embeddings."""
        self._cache.clear()

    @property
    def cache_size(self) -> int:
        """Number of cached embeddings."""
        return len(self._cache)