commit
9eddef589a
|
|
@ -35,7 +35,7 @@ class EmbeddingProvider(BaseModel):
|
|||
|
||||
class EmbeddingConfig(BaseModel):
|
||||
"""嵌入模型配置"""
|
||||
provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface")
|
||||
provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface, jina, qwen")
|
||||
model: str = Field(description="模型名称")
|
||||
api_key: Optional[str] = Field(default=None, description="API Key (如需要)")
|
||||
base_url: Optional[str] = Field(default=None, description="自定义 API 端点")
|
||||
|
|
@ -152,6 +152,18 @@ EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [
|
|||
requires_api_key=True,
|
||||
default_model="jina-embeddings-v2-base-code",
|
||||
),
|
||||
EmbeddingProvider(
|
||||
id="qwen",
|
||||
name="Qwen (DashScope)",
|
||||
description="阿里云 DashScope Qwen 嵌入模型,兼容 OpenAI embeddings 接口",
|
||||
models=[
|
||||
"text-embedding-v4",
|
||||
"text-embedding-v3",
|
||||
"text-embedding-v2",
|
||||
],
|
||||
requires_api_key=True,
|
||||
default_model="text-embedding-v4",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -397,6 +409,11 @@ def _get_model_dimensions(provider: str, model: str) -> int:
|
|||
"jina-embeddings-v2-base-code": 768,
|
||||
"jina-embeddings-v2-base-en": 768,
|
||||
"jina-embeddings-v2-base-zh": 768,
|
||||
|
||||
# Qwen (DashScope)
|
||||
"text-embedding-v4": 1024, # 支持维度: 2048, 1536, 1024(默认), 768, 512, 256, 128, 64
|
||||
"text-embedding-v3": 1024, # 支持维度: 1024(默认), 768, 512, 256, 128, 64
|
||||
"text-embedding-v2": 1536, # 支持维度: 1536
|
||||
}
|
||||
|
||||
return dimensions_map.get(model, 768)
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class Settings(BaseSettings):
|
|||
# ============ Agent 模块配置 ============
|
||||
|
||||
# 嵌入模型配置(独立于 LLM 配置)
|
||||
EMBEDDING_PROVIDER: str = "openai" # openai, azure, ollama, cohere, huggingface, jina
|
||||
EMBEDDING_PROVIDER: str = "openai" # openai, azure, ollama, cohere, huggingface, jina, qwen
|
||||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
EMBEDDING_API_KEY: Optional[str] = None # 嵌入模型专用 API Key(留空则使用 LLM_API_KEY)
|
||||
EMBEDDING_BASE_URL: Optional[str] = None # 嵌入模型专用 Base URL(留空使用提供商默认地址)
|
||||
|
|
|
|||
|
|
@ -463,6 +463,82 @@ class JinaEmbedding(EmbeddingProvider):
|
|||
return results
|
||||
|
||||
|
||||
class QwenEmbedding(EmbeddingProvider):
    """Qwen embedding service backed by Alibaba Cloud DashScope.

    Talks to DashScope's OpenAI-compatible ``/embeddings`` endpoint, so
    request and response handling mirrors the other OpenAI-style providers.
    """

    # DashScope Qwen embedding models mapped to their default vector sizes.
    MODELS = {
        "text-embedding-v4": 1024,  # supported dims: 2048, 1536, 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v3": 1024,  # supported dims: 1024 (default), 768, 512, 256, 128, 64
        "text-embedding-v2": 1536,  # supported dims: 1536
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "text-embedding-v4",
    ):
        # Key resolution order: explicit argument first, then
        # EMBEDDING_API_KEY / QWEN_API_KEY from settings, finally LLM_API_KEY.
        if api_key:
            self.api_key = api_key
        else:
            self.api_key = (
                getattr(settings, "EMBEDDING_API_KEY", None)
                or getattr(settings, "QWEN_API_KEY", None)
                or settings.LLM_API_KEY
            )
        # DashScope's OpenAI-compatible embeddings endpoint.
        self.base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        self.model = model
        self._dimension = self.MODELS.get(model, 1024)

    @property
    def dimension(self) -> int:
        """Vector size produced by the configured model."""
        return self._dimension

    async def embed_text(self, text: str) -> EmbeddingResult:
        """Embed a single string by delegating to the batch path."""
        return (await self.embed_texts([text]))[0]

    async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
        """Embed a batch of strings via the DashScope embeddings API.

        An empty input short-circuits to an empty list. Each text is clipped
        to 8191 characters — the same truncation policy the OpenAI-style
        providers use.
        """
        if not texts:
            return []

        # Truncation policy kept consistent with the OpenAI interface.
        limit = 8191
        clipped = [t[:limit] for t in texts]

        request_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        request_body = {
            "model": self.model,
            "input": clipped,
            "encoding_format": "float",
        }
        endpoint = f"{self.base_url.rstrip('/')}/embeddings"

        async with httpx.AsyncClient(timeout=60) as client:
            response = await client.post(endpoint, headers=request_headers, json=request_body)
            response.raise_for_status()
            data = response.json()

        usage = data.get("usage", {}) or {}
        total_tokens = usage.get("total_tokens") or usage.get("prompt_tokens") or 0
        # The API reports token usage per request; spread it evenly per input.
        per_text_tokens = total_tokens // max(len(texts), 1)

        return [
            EmbeddingResult(
                embedding=item["embedding"],
                tokens_used=per_text_tokens,
                model=self.model,
            )
            for item in data.get("data", [])
        ]
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
"""
|
||||
嵌入服务
|
||||
|
|
@ -539,6 +615,9 @@ class EmbeddingService:
|
|||
elif provider == "jina":
|
||||
return JinaEmbedding(api_key=api_key, base_url=base_url, model=model)
|
||||
|
||||
elif provider == "qwen":
|
||||
return QwenEmbedding(api_key=api_key, base_url=base_url, model=model)
|
||||
|
||||
else:
|
||||
# 默认使用 OpenAI
|
||||
return OpenAIEmbedding(api_key=api_key, base_url=base_url, model=model)
|
||||
|
|
|
|||
Loading…
Reference in New Issue