From d827ab8b0380f92b1cafeb68369d2dd1a31ebe23 Mon Sep 17 00:00:00 2001
From: w1_liamby
Date: Fri, 19 Dec 2025 15:14:39 +0800
Subject: [PATCH] =?UTF-8?q?[feat]=20=E6=B7=BB=E5=8A=A0=20qwen=20=E5=B5=8C?=
 =?UTF-8?q?=E5=85=A5=E6=A8=A1=E5=9E=8B=E6=8F=90=E4=BE=9B=E5=95=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../app/api/v1/endpoints/embedding_config.py |  19 +++-
 backend/app/core/config.py                    |   2 +-
 backend/app/services/rag/embeddings.py        | 113 +++++++++++++++++++
 3 files changed, 132 insertions(+), 2 deletions(-)

diff --git a/backend/app/api/v1/endpoints/embedding_config.py b/backend/app/api/v1/endpoints/embedding_config.py
index 541bf2a..156ea8e 100644
--- a/backend/app/api/v1/endpoints/embedding_config.py
+++ b/backend/app/api/v1/endpoints/embedding_config.py
@@ -35,7 +35,7 @@ class EmbeddingProvider(BaseModel):
 
 class EmbeddingConfig(BaseModel):
     """嵌入模型配置"""
-    provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface")
+    provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface, jina, qwen")
     model: str = Field(description="模型名称")
     api_key: Optional[str] = Field(default=None, description="API Key (如需要)")
     base_url: Optional[str] = Field(default=None, description="自定义 API 端点")
@@ -152,6 +152,18 @@ EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [
         requires_api_key=True,
         default_model="jina-embeddings-v2-base-code",
     ),
+    EmbeddingProvider(
+        id="qwen",
+        name="Qwen (DashScope)",
+        description="阿里云 DashScope Qwen 嵌入模型,兼容 OpenAI embeddings 接口",
+        models=[
+            "text-embedding-v4",
+            "text-embedding-v3",
+            "text-embedding-v2",
+        ],
+        requires_api_key=True,
+        default_model="text-embedding-v4",
+    ),
 ]
 
 
@@ -397,6 +409,11 @@ def _get_model_dimensions(provider: str, model: str) -> int:
         "jina-embeddings-v2-base-code": 768,
         "jina-embeddings-v2-base-en": 768,
         "jina-embeddings-v2-base-zh": 768,
+
+        # Qwen (DashScope)
+        "text-embedding-v4": 1024,  # supported dimensions: 2048, 1536, 1024 (default), 768, 512, 256, 128, 64
+        "text-embedding-v3": 1024,  # supported dimensions: 1024 (default), 768, 512, 256, 128, 64
+        "text-embedding-v2": 1536,  # supported dimensions: 1536
     }
 
     return dimensions_map.get(model, 768)
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 43d6385..bb7a2b8 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -80,7 +80,7 @@ class Settings(BaseSettings):
     # ============ Agent 模块配置 ============
 
     # 嵌入模型配置(独立于 LLM 配置)
-    EMBEDDING_PROVIDER: str = "openai"  # openai, azure, ollama, cohere, huggingface, jina
+    EMBEDDING_PROVIDER: str = "openai"  # openai, azure, ollama, cohere, huggingface, jina, qwen
     EMBEDDING_MODEL: str = "text-embedding-3-small"
     EMBEDDING_API_KEY: Optional[str] = None  # 嵌入模型专用 API Key(留空则使用 LLM_API_KEY)
     EMBEDDING_BASE_URL: Optional[str] = None  # 嵌入模型专用 Base URL(留空使用提供商默认地址)
diff --git a/backend/app/services/rag/embeddings.py b/backend/app/services/rag/embeddings.py
index 1f8b910..75dd881 100644
--- a/backend/app/services/rag/embeddings.py
+++ b/backend/app/services/rag/embeddings.py
@@ -463,6 +463,116 @@ class JinaEmbedding(EmbeddingProvider):
         return results
 
 
+class QwenEmbedding(EmbeddingProvider):
+    """Qwen embedding provider backed by the Alibaba Cloud DashScope embeddings API.
+
+    Requests go to DashScope's OpenAI-compatible ``/embeddings`` endpoint.
+    """
+
+    MODELS = {
+        # DashScope Qwen embedding models and their default output dimension
+        "text-embedding-v4": 1024,  # supported dimensions: 2048, 1536, 1024 (default), 768, 512, 256, 128, 64
+        "text-embedding-v3": 1024,  # supported dimensions: 1024 (default), 768, 512, 256, 128, 64
+        "text-embedding-v2": 1536,  # supported dimensions: 1536
+    }
+
+    # DashScope rejects embedding requests that carry too many inputs; per the
+    # vendor docs text-embedding-v3/v4 accept at most 10 texts per call (v2
+    # allows more), so 10 is safe for every model listed above.
+    MAX_BATCH_SIZE = 10
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        model: str = "text-embedding-v4",
+    ):
+        """Configure the provider.
+
+        Args:
+            api_key: Explicit API key. When omitted, falls back to
+                ``EMBEDDING_API_KEY``, then ``QWEN_API_KEY``, then
+                ``LLM_API_KEY`` from settings.
+            base_url: Custom endpoint; defaults to the DashScope
+                OpenAI-compatible URL.
+            model: Embedding model name. Models missing from ``MODELS`` are
+                assumed to produce 1024-dimensional vectors.
+        """
+        self.api_key = (
+            api_key
+            or getattr(settings, "EMBEDDING_API_KEY", None)
+            or getattr(settings, "QWEN_API_KEY", None)
+            or settings.LLM_API_KEY
+        )
+        self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
+        self.model = model
+        self._dimension = self.MODELS.get(model, 1024)
+
+    @property
+    def dimension(self) -> int:
+        """Vector size assumed for the configured model."""
+        return self._dimension
+
+    async def embed_text(self, text: str) -> EmbeddingResult:
+        """Embed a single text by delegating to :meth:`embed_texts`."""
+        results = await self.embed_texts([text])
+        return results[0]
+
+    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
+        """Embed ``texts`` and return the results in input order.
+
+        Inputs are sent in chunks of at most ``MAX_BATCH_SIZE`` so that large
+        batches are not rejected by DashScope.
+
+        Raises:
+            httpx.HTTPStatusError: If any request returns an error status.
+        """
+        if not texts:
+            return []
+
+        # Truncation kept consistent with the OpenAI provider.
+        max_length = 8191
+        truncated_texts = [text[:max_length] for text in texts]
+
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+        url = f"{self.base_url.rstrip('/')}/embeddings"
+
+        results: List[EmbeddingResult] = []
+        # Share one client across all chunks.
+        async with httpx.AsyncClient(timeout=60) as client:
+            for start in range(0, len(truncated_texts), self.MAX_BATCH_SIZE):
+                batch = truncated_texts[start:start + self.MAX_BATCH_SIZE]
+                payload = {
+                    "model": self.model,
+                    "input": batch,
+                    "encoding_format": "float",
+                }
+                response = await client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
+
+                # Usage is reported per request; spread it evenly over the batch.
+                usage = data.get("usage", {}) or {}
+                total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
+                tokens_per_text = total_tokens // len(batch)
+
+                # Order by the returned index so results line up with the inputs.
+                items = sorted(data.get("data", []), key=lambda item: item.get("index", 0))
+                results.extend(
+                    EmbeddingResult(
+                        embedding=item["embedding"],
+                        tokens_used=tokens_per_text,
+                        model=self.model,
+                    )
+                    for item in items
+                )
+
+        return results
+
+
 class EmbeddingService:
     """
     嵌入服务
@@ -539,6 +649,9 @@ class EmbeddingService:
         elif provider == "jina":
             return JinaEmbedding(api_key=api_key, base_url=base_url, model=model)
 
+        elif provider == "qwen":
+            return QwenEmbedding(api_key=api_key, base_url=base_url, model=model)
+
         else:
             # 默认使用 OpenAI
             return OpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)