"""
|
||
嵌入模型配置 API
|
||
独立于 LLM 配置,专门用于 RAG 系统的嵌入模型
|
||
使用 UserConfig.other_config 持久化存储
|
||
"""
|
||
|
||
import json
|
||
import uuid
|
||
from typing import Any, Optional, List
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm.attributes import flag_modified
|
||
|
||
from app.api import deps
|
||
from app.models.user import User
|
||
from app.models.user_config import UserConfig
|
||
from app.core.config import settings
|
||
|
||
router = APIRouter()
|
||
|
||
|
||
# ============ Schemas ============
|
||
|
||
class EmbeddingProvider(BaseModel):
|
||
"""嵌入模型提供商"""
|
||
id: str
|
||
name: str
|
||
description: str
|
||
models: List[str]
|
||
requires_api_key: bool
|
||
default_model: str
|
||
|
||
|
||
class EmbeddingConfig(BaseModel):
|
||
"""嵌入模型配置"""
|
||
provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface, jina, qwen")
|
||
model: str = Field(description="模型名称")
|
||
api_key: Optional[str] = Field(default=None, description="API Key (如需要)")
|
||
base_url: Optional[str] = Field(default=None, description="自定义 API 端点")
|
||
dimensions: Optional[int] = Field(default=None, description="向量维度 (某些模型支持)")
|
||
batch_size: int = Field(default=100, description="批处理大小")
|
||
|
||
|
||
class EmbeddingConfigResponse(BaseModel):
|
||
"""配置响应"""
|
||
provider: str
|
||
model: str
|
||
api_key: Optional[str] = None # 返回 API Key
|
||
base_url: Optional[str]
|
||
dimensions: int
|
||
batch_size: int
|
||
|
||
|
||
class TestEmbeddingRequest(BaseModel):
|
||
"""测试嵌入请求"""
|
||
provider: str
|
||
model: str
|
||
api_key: Optional[str] = None
|
||
base_url: Optional[str] = None
|
||
test_text: str = "这是一段测试文本,用于验证嵌入模型是否正常工作。"
|
||
|
||
|
||
class TestEmbeddingResponse(BaseModel):
|
||
"""测试嵌入响应"""
|
||
success: bool
|
||
message: str
|
||
dimensions: Optional[int] = None
|
||
sample_embedding: Optional[List[float]] = None # 前 5 个维度
|
||
latency_ms: Optional[int] = None
|
||
|
||
|
||
# ============ 提供商配置 ============
|
||
|
||
EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [
|
||
EmbeddingProvider(
|
||
id="openai",
|
||
name="OpenAI (兼容 DeepSeek/Moonshot/智谱 等)",
|
||
description="OpenAI 官方或兼容 API,填写自定义端点可接入其他服务商",
|
||
models=[
|
||
"text-embedding-3-small",
|
||
"text-embedding-3-large",
|
||
"text-embedding-ada-002",
|
||
],
|
||
requires_api_key=True,
|
||
default_model="text-embedding-3-small",
|
||
),
|
||
EmbeddingProvider(
|
||
id="azure",
|
||
name="Azure OpenAI",
|
||
description="Azure 托管的 OpenAI 嵌入模型",
|
||
models=[
|
||
"text-embedding-3-small",
|
||
"text-embedding-3-large",
|
||
"text-embedding-ada-002",
|
||
],
|
||
requires_api_key=True,
|
||
default_model="text-embedding-3-small",
|
||
),
|
||
EmbeddingProvider(
|
||
id="ollama",
|
||
name="Ollama (本地)",
|
||
description="本地运行的开源嵌入模型 (使用 /api/embed 端点)",
|
||
models=[
|
||
"nomic-embed-text",
|
||
"mxbai-embed-large",
|
||
"all-minilm",
|
||
"snowflake-arctic-embed",
|
||
"bge-m3",
|
||
"qwen3-embedding",
|
||
],
|
||
requires_api_key=False,
|
||
default_model="nomic-embed-text",
|
||
),
|
||
EmbeddingProvider(
|
||
id="cohere",
|
||
name="Cohere",
|
||
description="Cohere Embed v2 API (api.cohere.com/v2)",
|
||
models=[
|
||
"embed-english-v3.0",
|
||
"embed-multilingual-v3.0",
|
||
"embed-english-light-v3.0",
|
||
"embed-multilingual-light-v3.0",
|
||
"embed-v4.0",
|
||
],
|
||
requires_api_key=True,
|
||
default_model="embed-multilingual-v3.0",
|
||
),
|
||
EmbeddingProvider(
|
||
id="huggingface",
|
||
name="HuggingFace",
|
||
description="HuggingFace Inference Providers (router.huggingface.co)",
|
||
models=[
|
||
"sentence-transformers/all-MiniLM-L6-v2",
|
||
"sentence-transformers/all-mpnet-base-v2",
|
||
"BAAI/bge-large-zh-v1.5",
|
||
"BAAI/bge-m3",
|
||
],
|
||
requires_api_key=True,
|
||
default_model="BAAI/bge-m3",
|
||
),
|
||
EmbeddingProvider(
|
||
id="jina",
|
||
name="Jina AI",
|
||
description="Jina AI 嵌入模型,代码嵌入效果好",
|
||
models=[
|
||
"jina-embeddings-v2-base-code",
|
||
"jina-embeddings-v2-base-en",
|
||
"jina-embeddings-v2-base-zh",
|
||
],
|
||
requires_api_key=True,
|
||
default_model="jina-embeddings-v2-base-code",
|
||
),
|
||
EmbeddingProvider(
|
||
id="qwen",
|
||
name="Qwen (DashScope)",
|
||
description="阿里云 DashScope Qwen 嵌入模型,兼容 OpenAI embeddings 接口",
|
||
models=[
|
||
"text-embedding-v4",
|
||
"text-embedding-v3",
|
||
"text-embedding-v2",
|
||
],
|
||
requires_api_key=True,
|
||
default_model="text-embedding-v4",
|
||
),
|
||
]
|
||
|
||
|
||
# ============ 数据库持久化存储 (异步) ============
|
||
|
||
EMBEDDING_CONFIG_KEY = "embedding_config"
|
||
|
||
|
||
def mask_api_key(key: Optional[str]) -> str:
|
||
"""部分遮盖API Key,显示前3位和后4位"""
|
||
if not key:
|
||
return ""
|
||
if len(key) <= 8:
|
||
return "***"
|
||
return f"{key[:3]}***{key[-4:]}"
|
||
|
||
|
||
async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> EmbeddingConfig:
|
||
"""从数据库获取嵌入配置(异步)"""
|
||
# 嵌入配置始终来自系统默认(.env),不再允许用户覆盖
|
||
print(f"[EmbeddingConfig] 返回系统默认嵌入配置(来自 .env)")
|
||
return EmbeddingConfig(
|
||
provider=settings.EMBEDDING_PROVIDER,
|
||
model=settings.EMBEDDING_MODEL,
|
||
api_key=mask_api_key(settings.EMBEDDING_API_KEY or settings.LLM_API_KEY),
|
||
base_url=settings.EMBEDDING_BASE_URL or settings.LLM_BASE_URL,
|
||
dimensions=settings.EMBEDDING_DIMENSION if settings.EMBEDDING_DIMENSION > 0 else None,
|
||
batch_size=100,
|
||
)
|
||
|
||
|
||
|
||
|
||
# ============ API Endpoints ============
|
||
|
||
@router.get("/providers", response_model=List[EmbeddingProvider])
|
||
async def list_embedding_providers(
|
||
current_user: User = Depends(deps.get_current_user),
|
||
) -> Any:
|
||
"""
|
||
获取可用的嵌入模型提供商列表
|
||
"""
|
||
return EMBEDDING_PROVIDERS
|
||
|
||
|
||
@router.get("/config", response_model=EmbeddingConfigResponse)
|
||
async def get_current_config(
|
||
db: AsyncSession = Depends(deps.get_db),
|
||
current_user: User = Depends(deps.get_current_user),
|
||
) -> Any:
|
||
"""
|
||
获取当前嵌入模型配置(从数据库读取)
|
||
"""
|
||
config = await get_embedding_config_from_db(db, current_user.id)
|
||
|
||
# 获取维度
|
||
dimensions = config.dimensions
|
||
if not dimensions or dimensions <= 0:
|
||
dimensions = _get_model_dimensions(config.provider, config.model)
|
||
|
||
return EmbeddingConfigResponse(
|
||
provider=config.provider,
|
||
model=config.model,
|
||
api_key=config.api_key,
|
||
base_url=config.base_url,
|
||
dimensions=dimensions,
|
||
batch_size=config.batch_size,
|
||
)
|
||
|
||
|
||
@router.put("/config")
|
||
async def update_config(
|
||
current_user: User = Depends(deps.get_current_user),
|
||
) -> Any:
|
||
"""
|
||
更新嵌入模型配置(已禁用,固定从 .env 读取)
|
||
"""
|
||
return {"message": "嵌入模型配置已锁定,请在 .env 文件中进行修改", "provider": settings.EMBEDDING_PROVIDER, "model": settings.EMBEDDING_MODEL}
|
||
|
||
|
||
@router.post("/test", response_model=TestEmbeddingResponse)
|
||
async def test_embedding(
|
||
request: TestEmbeddingRequest,
|
||
current_user: User = Depends(deps.get_current_user),
|
||
) -> Any:
|
||
"""测试当前系统的嵌入模型配置"""
|
||
import time
|
||
from app.services.rag.embeddings import EmbeddingService
|
||
|
||
try:
|
||
start_time = time.time()
|
||
|
||
# 始终使用系统的真实配置进行测试
|
||
# API Key 优先级:EMBEDDING_API_KEY > LLM_API_KEY
|
||
api_key = getattr(settings, "EMBEDDING_API_KEY", None) or settings.LLM_API_KEY
|
||
|
||
service = EmbeddingService(
|
||
provider=settings.EMBEDDING_PROVIDER,
|
||
model=settings.EMBEDDING_MODEL,
|
||
api_key=api_key,
|
||
base_url=settings.EMBEDDING_BASE_URL,
|
||
cache_enabled=False,
|
||
)
|
||
|
||
# 执行嵌入
|
||
embedding = await service.embed(request.test_text)
|
||
|
||
latency_ms = int((time.time() - start_time) * 1000)
|
||
|
||
return TestEmbeddingResponse(
|
||
success=True,
|
||
message=f"嵌入成功! 提供商: {settings.EMBEDDING_PROVIDER}, 维度: {len(embedding)}",
|
||
dimensions=len(embedding),
|
||
sample_embedding=embedding[:5], # 返回前 5 维
|
||
latency_ms=latency_ms,
|
||
)
|
||
|
||
except Exception as e:
|
||
print(f"❌ 嵌入测试失败: {str(e)}")
|
||
return TestEmbeddingResponse(
|
||
success=False,
|
||
message=f"嵌入失败: {str(e)}",
|
||
)
|
||
|
||
|
||
@router.get("/models/{provider}")
|
||
async def get_provider_models(
|
||
provider: str,
|
||
current_user: User = Depends(deps.get_current_user),
|
||
) -> Any:
|
||
"""
|
||
获取指定提供商的模型列表
|
||
"""
|
||
provider_info = next((p for p in EMBEDDING_PROVIDERS if p.id == provider), None)
|
||
|
||
if not provider_info:
|
||
raise HTTPException(status_code=404, detail=f"提供商不存在: {provider}")
|
||
|
||
return {
|
||
"provider": provider,
|
||
"models": provider_info.models,
|
||
"default_model": provider_info.default_model,
|
||
"requires_api_key": provider_info.requires_api_key,
|
||
}
|
||
|
||
|
||
def _get_model_dimensions(provider: str, model: str) -> int:
|
||
"""获取模型维度"""
|
||
dimensions_map = {
|
||
# OpenAI
|
||
"text-embedding-3-small": 1536,
|
||
"text-embedding-3-large": 3072,
|
||
"text-embedding-ada-002": 1536,
|
||
|
||
# Ollama
|
||
"nomic-embed-text": 768,
|
||
"mxbai-embed-large": 1024,
|
||
"all-minilm": 384,
|
||
"snowflake-arctic-embed": 1024,
|
||
|
||
# Cohere
|
||
"embed-english-v3.0": 1024,
|
||
"embed-multilingual-v3.0": 1024,
|
||
"embed-english-light-v3.0": 384,
|
||
"embed-multilingual-light-v3.0": 384,
|
||
|
||
# HuggingFace
|
||
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
||
"sentence-transformers/all-mpnet-base-v2": 768,
|
||
"BAAI/bge-large-zh-v1.5": 1024,
|
||
"BAAI/bge-m3": 1024,
|
||
|
||
# Jina
|
||
"jina-embeddings-v2-base-code": 768,
|
||
"jina-embeddings-v2-base-en": 768,
|
||
"jina-embeddings-v2-base-zh": 768,
|
||
|
||
# Qwen (DashScope)
|
||
"text-embedding-v4": 1024, # 支持维度: 2048, 1536, 1024(默认), 768, 512, 256, 128, 64
|
||
"text-embedding-v3": 1024, # 支持维度: 1024(默认), 768, 512, 256, 128, 64
|
||
"text-embedding-v2": 1536, # 支持维度: 1536
|
||
}
|
||
|
||
return dimensions_map.get(model, 768)
|
||
|