CodeReview/backend/app/api/v1/endpoints/embedding_config.py

397 lines
12 KiB
Python

"""
嵌入模型配置 API
独立于 LLM 配置,专门用于 RAG 系统的嵌入模型
使用 UserConfig.other_config 持久化存储
"""
import json
import uuid
from typing import Any, Optional, List
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api import deps
from app.models.user import User
from app.models.user_config import UserConfig
from app.core.config import settings
router = APIRouter()
# ============ Schemas ============
class EmbeddingProvider(BaseModel):
"""嵌入模型提供商"""
id: str
name: str
description: str
models: List[str]
requires_api_key: bool
default_model: str
class EmbeddingConfig(BaseModel):
"""嵌入模型配置"""
provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface")
model: str = Field(description="模型名称")
api_key: Optional[str] = Field(default=None, description="API Key (如需要)")
base_url: Optional[str] = Field(default=None, description="自定义 API 端点")
dimensions: Optional[int] = Field(default=None, description="向量维度 (某些模型支持)")
batch_size: int = Field(default=100, description="批处理大小")
class EmbeddingConfigResponse(BaseModel):
"""配置响应"""
provider: str
model: str
base_url: Optional[str]
dimensions: int
batch_size: int
# 不返回 API Key
class TestEmbeddingRequest(BaseModel):
"""测试嵌入请求"""
provider: str
model: str
api_key: Optional[str] = None
base_url: Optional[str] = None
test_text: str = "这是一段测试文本,用于验证嵌入模型是否正常工作。"
class TestEmbeddingResponse(BaseModel):
"""测试嵌入响应"""
success: bool
message: str
dimensions: Optional[int] = None
sample_embedding: Optional[List[float]] = None # 前 5 个维度
latency_ms: Optional[int] = None
# ============ 提供商配置 ============
EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [
EmbeddingProvider(
id="openai",
name="OpenAI",
description="OpenAI 官方嵌入模型,高质量、稳定",
models=[
"text-embedding-3-small",
"text-embedding-3-large",
"text-embedding-ada-002",
],
requires_api_key=True,
default_model="text-embedding-3-small",
),
EmbeddingProvider(
id="azure",
name="Azure OpenAI",
description="Azure 托管的 OpenAI 嵌入模型",
models=[
"text-embedding-3-small",
"text-embedding-3-large",
"text-embedding-ada-002",
],
requires_api_key=True,
default_model="text-embedding-3-small",
),
EmbeddingProvider(
id="ollama",
name="Ollama (本地)",
description="本地运行的开源嵌入模型 (使用 /api/embed 端点)",
models=[
"nomic-embed-text",
"mxbai-embed-large",
"all-minilm",
"snowflake-arctic-embed",
"bge-m3",
"qwen3-embedding",
],
requires_api_key=False,
default_model="nomic-embed-text",
),
EmbeddingProvider(
id="cohere",
name="Cohere",
description="Cohere Embed v2 API (api.cohere.com/v2)",
models=[
"embed-english-v3.0",
"embed-multilingual-v3.0",
"embed-english-light-v3.0",
"embed-multilingual-light-v3.0",
"embed-v4.0",
],
requires_api_key=True,
default_model="embed-multilingual-v3.0",
),
EmbeddingProvider(
id="huggingface",
name="HuggingFace",
description="HuggingFace Inference Providers (router.huggingface.co)",
models=[
"sentence-transformers/all-MiniLM-L6-v2",
"sentence-transformers/all-mpnet-base-v2",
"BAAI/bge-large-zh-v1.5",
"BAAI/bge-m3",
],
requires_api_key=True,
default_model="BAAI/bge-m3",
),
EmbeddingProvider(
id="jina",
name="Jina AI",
description="Jina AI 嵌入模型,代码嵌入效果好",
models=[
"jina-embeddings-v2-base-code",
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-zh",
],
requires_api_key=True,
default_model="jina-embeddings-v2-base-code",
),
]
# ============ 数据库持久化存储 (异步) ============
EMBEDDING_CONFIG_KEY = "embedding_config"
async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> EmbeddingConfig:
"""从数据库获取嵌入配置(异步)"""
result = await db.execute(
select(UserConfig).where(UserConfig.user_id == user_id)
)
user_config = result.scalar_one_or_none()
if user_config and user_config.other_config:
try:
other_config = json.loads(user_config.other_config) if isinstance(user_config.other_config, str) else user_config.other_config
embedding_data = other_config.get(EMBEDDING_CONFIG_KEY)
if embedding_data:
return EmbeddingConfig(
provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER),
model=embedding_data.get("model", settings.EMBEDDING_MODEL),
api_key=embedding_data.get("api_key"),
base_url=embedding_data.get("base_url"),
dimensions=embedding_data.get("dimensions"),
batch_size=embedding_data.get("batch_size", 100),
)
except (json.JSONDecodeError, AttributeError):
pass
# 返回默认配置
return EmbeddingConfig(
provider=settings.EMBEDDING_PROVIDER,
model=settings.EMBEDDING_MODEL,
api_key=settings.LLM_API_KEY,
base_url=settings.LLM_BASE_URL,
batch_size=100,
)
async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: EmbeddingConfig) -> None:
"""保存嵌入配置到数据库(异步)"""
result = await db.execute(
select(UserConfig).where(UserConfig.user_id == user_id)
)
user_config = result.scalar_one_or_none()
# 准备嵌入配置数据
embedding_data = {
"provider": config.provider,
"model": config.model,
"api_key": config.api_key,
"base_url": config.base_url,
"dimensions": config.dimensions,
"batch_size": config.batch_size,
}
if user_config:
# 更新现有配置
try:
other_config = json.loads(user_config.other_config) if user_config.other_config else {}
except (json.JSONDecodeError, TypeError):
other_config = {}
other_config[EMBEDDING_CONFIG_KEY] = embedding_data
user_config.other_config = json.dumps(other_config)
else:
# 创建新配置
user_config = UserConfig(
id=str(uuid.uuid4()),
user_id=user_id,
llm_config="{}",
other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}),
)
db.add(user_config)
await db.commit()
# ============ API Endpoints ============
@router.get("/providers", response_model=List[EmbeddingProvider])
async def list_embedding_providers(
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
获取可用的嵌入模型提供商列表
"""
return EMBEDDING_PROVIDERS
@router.get("/config", response_model=EmbeddingConfigResponse)
async def get_current_config(
db: AsyncSession = Depends(deps.get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
获取当前嵌入模型配置(从数据库读取)
"""
config = await get_embedding_config_from_db(db, current_user.id)
# 获取维度
dimensions = _get_model_dimensions(config.provider, config.model)
return EmbeddingConfigResponse(
provider=config.provider,
model=config.model,
base_url=config.base_url,
dimensions=dimensions,
batch_size=config.batch_size,
)
@router.put("/config")
async def update_config(
config: EmbeddingConfig,
db: AsyncSession = Depends(deps.get_db),
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
更新嵌入模型配置(持久化到数据库)
"""
# 验证提供商
provider_ids = [p.id for p in EMBEDDING_PROVIDERS]
if config.provider not in provider_ids:
raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}")
# 验证模型
provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None)
if provider and config.model not in provider.models:
raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}")
# 检查 API Key
if provider and provider.requires_api_key and not config.api_key:
raise HTTPException(status_code=400, detail=f"{config.provider} 需要 API Key")
# 保存到数据库
await save_embedding_config_to_db(db, current_user.id, config)
return {"message": "配置已保存", "provider": config.provider, "model": config.model}
@router.post("/test", response_model=TestEmbeddingResponse)
async def test_embedding(
request: TestEmbeddingRequest,
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
测试嵌入模型配置
"""
import time
try:
start_time = time.time()
# 创建临时嵌入服务
from app.services.rag.embeddings import EmbeddingService
service = EmbeddingService(
provider=request.provider,
model=request.model,
api_key=request.api_key,
base_url=request.base_url,
cache_enabled=False,
)
# 执行嵌入
embedding = await service.embed(request.test_text)
latency_ms = int((time.time() - start_time) * 1000)
return TestEmbeddingResponse(
success=True,
message=f"嵌入成功! 维度: {len(embedding)}",
dimensions=len(embedding),
sample_embedding=embedding[:5], # 返回前 5 维
latency_ms=latency_ms,
)
except Exception as e:
return TestEmbeddingResponse(
success=False,
message=f"嵌入失败: {str(e)}",
)
@router.get("/models/{provider}")
async def get_provider_models(
provider: str,
current_user: User = Depends(deps.get_current_user),
) -> Any:
"""
获取指定提供商的模型列表
"""
provider_info = next((p for p in EMBEDDING_PROVIDERS if p.id == provider), None)
if not provider_info:
raise HTTPException(status_code=404, detail=f"提供商不存在: {provider}")
return {
"provider": provider,
"models": provider_info.models,
"default_model": provider_info.default_model,
"requires_api_key": provider_info.requires_api_key,
}
def _get_model_dimensions(provider: str, model: str) -> int:
"""获取模型维度"""
dimensions_map = {
# OpenAI
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
# Ollama
"nomic-embed-text": 768,
"mxbai-embed-large": 1024,
"all-minilm": 384,
"snowflake-arctic-embed": 1024,
# Cohere
"embed-english-v3.0": 1024,
"embed-multilingual-v3.0": 1024,
"embed-english-light-v3.0": 384,
"embed-multilingual-light-v3.0": 384,
# HuggingFace
"sentence-transformers/all-MiniLM-L6-v2": 384,
"sentence-transformers/all-mpnet-base-v2": 768,
"BAAI/bge-large-zh-v1.5": 1024,
"BAAI/bge-m3": 1024,
# Jina
"jina-embeddings-v2-base-code": 768,
"jina-embeddings-v2-base-en": 768,
"jina-embeddings-v2-base-zh": 768,
}
return dimensions_map.get(model, 768)