Merge pull request #83 from WilliamBy/v3-dev

[feat] 添加 qwen 嵌入模型提供商
This commit is contained in:
lintsinghua 2025-12-19 16:24:11 +08:00 committed by GitHub
commit 9eddef589a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 98 additions and 2 deletions

View File

@ -35,7 +35,7 @@ class EmbeddingProvider(BaseModel):
class EmbeddingConfig(BaseModel):
"""嵌入模型配置"""
provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface")
provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface, jina, qwen")
model: str = Field(description="模型名称")
api_key: Optional[str] = Field(default=None, description="API Key (如需要)")
base_url: Optional[str] = Field(default=None, description="自定义 API 端点")
@ -152,6 +152,18 @@ EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [
requires_api_key=True,
default_model="jina-embeddings-v2-base-code",
),
EmbeddingProvider(
id="qwen",
name="Qwen (DashScope)",
description="阿里云 DashScope Qwen 嵌入模型,兼容 OpenAI embeddings 接口",
models=[
"text-embedding-v4",
"text-embedding-v3",
"text-embedding-v2",
],
requires_api_key=True,
default_model="text-embedding-v4",
),
]
@ -397,6 +409,11 @@ def _get_model_dimensions(provider: str, model: str) -> int:
"jina-embeddings-v2-base-code": 768,
"jina-embeddings-v2-base-en": 768,
"jina-embeddings-v2-base-zh": 768,
# Qwen (DashScope)
"text-embedding-v4": 1024, # 支持维度: 2048, 1536, 1024(默认), 768, 512, 256, 128, 64
"text-embedding-v3": 1024, # 支持维度: 1024(默认), 768, 512, 256, 128, 64
"text-embedding-v2": 1536, # 支持维度: 1536
}
return dimensions_map.get(model, 768)

View File

@ -80,7 +80,7 @@ class Settings(BaseSettings):
# ============ Agent 模块配置 ============
# 嵌入模型配置(独立于 LLM 配置)
EMBEDDING_PROVIDER: str = "openai" # openai, azure, ollama, cohere, huggingface, jina
EMBEDDING_PROVIDER: str = "openai" # openai, azure, ollama, cohere, huggingface, jina, qwen
EMBEDDING_MODEL: str = "text-embedding-3-small"
EMBEDDING_API_KEY: Optional[str] = None # 嵌入模型专用 API Key留空则使用 LLM_API_KEY
EMBEDDING_BASE_URL: Optional[str] = None # 嵌入模型专用 Base URL留空使用提供商默认地址

View File

@ -463,6 +463,82 @@ class JinaEmbedding(EmbeddingProvider):
return results
class QwenEmbedding(EmbeddingProvider):
    """Qwen embedding service backed by Alibaba Cloud DashScope.

    Requests are sent to DashScope's OpenAI-compatible ``/embeddings``
    endpoint. Inputs are split into batches because DashScope limits how many
    texts a single embeddings request may carry.
    """

    # DashScope Qwen embedding models and their default output dimension.
    MODELS = {
        "text-embedding-v4": 1024,  # supported: 2048, 1536, 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v3": 1024,  # supported: 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v2": 1536,  # supported: 1536
    }

    # Maximum number of input texts per request, per model.
    # NOTE(review): values taken from the DashScope embedding documentation;
    # re-check them if requests start failing with batch-size errors.
    MAX_BATCH_SIZES = {
        "text-embedding-v4": 10,
        "text-embedding-v3": 10,
        "text-embedding-v2": 25,
    }
    # Conservative fallback for models not listed in MAX_BATCH_SIZES.
    DEFAULT_MAX_BATCH_SIZE = 10

    # Character-based truncation, kept identical to the OpenAI provider's
    # truncation strategy.
    MAX_TEXT_LENGTH = 8191

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-v4",
    ):
        """Configure the provider.

        Args:
            api_key: Explicit API key. When falsy, falls back to
                ``settings.EMBEDDING_API_KEY``, then ``settings.QWEN_API_KEY``
                (both optional attributes), then ``settings.LLM_API_KEY``.
            base_url: Custom endpoint root. Defaults to the DashScope
                OpenAI-compatible endpoint.
            model: Embedding model name. Unknown models are assumed to
                produce 1024-dimensional vectors.
        """
        # Prefer the explicitly passed key, then the settings fallbacks in order.
        self.api_key = (
            api_key
            or getattr(settings, "EMBEDDING_API_KEY", None)
            or getattr(settings, "QWEN_API_KEY", None)
            or settings.LLM_API_KEY
        )
        # DashScope exposes an OpenAI-compatible embeddings endpoint.
        self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        self.model = model
        self._dimension = self.MODELS.get(model, 1024)

    @property
    def dimension(self) -> int:
        """Return the default embedding dimension of the configured model."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single text and return its result."""
        results = await self.embed_texts([text])
        return results[0]

    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed ``texts`` and return one result per returned embedding.

        The input is truncated per text, split into batches that respect the
        model's per-request limit, and sent sequentially over one HTTP client.
        Results are returned in input order.

        Args:
            texts: Texts to embed. An empty list returns an empty list without
                any network call.

        Returns:
            Embedding results in the same order as ``texts``. ``tokens_used``
            is the batch's reported token total divided evenly across the
            texts of that batch.

        Raises:
            httpx.HTTPStatusError: If any request returns an error status.
        """
        if not texts:
            return []

        truncated_texts = [text[: self.MAX_TEXT_LENGTH] for text in texts]
        batch_size = self.MAX_BATCH_SIZES.get(
            self.model, self.DEFAULT_MAX_BATCH_SIZE
        )

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        url = f"{self.base_url.rstrip('/')}/embeddings"

        results: List[EmbeddingResult] = []
        async with httpx.AsyncClient(timeout=60) as client:
            for start in range(0, len(truncated_texts), batch_size):
                batch = truncated_texts[start:start + batch_size]
                payload = {
                    "model": self.model,
                    "input": batch,
                    "encoding_format": "float",
                }

                response = await client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()

                usage = data.get("usage", {}) or {}
                total_tokens = (
                    usage.get("total_tokens") or usage.get("prompt_tokens") or 0
                )
                tokens_per_text = total_tokens // len(batch)

                # Order by the response's "index" field so embeddings line up
                # with the inputs; the sort is stable, so items without an
                # index keep the order in which they were returned.
                items = sorted(
                    data.get("data", []),
                    key=lambda item: item.get("index", 0),
                )
                results.extend(
                    EmbeddingResult(
                        embedding=item["embedding"],
                        tokens_used=tokens_per_text,
                        model=self.model,
                    )
                    for item in items
                )

        return results
class EmbeddingService:
"""
嵌入服务
@ -539,6 +615,9 @@ class EmbeddingService:
elif provider == "jina":
return JinaEmbedding(api_key=api_key, base_url=base_url, model=model)
elif provider == "qwen":
return QwenEmbedding(api_key=api_key, base_url=base_url, model=model)
else:
# 默认使用 OpenAI
return OpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)