CodeReview/backend/app/api/v1/endpoints/embedding_config.py

352 lines
11 KiB
Python
Raw Normal View History

"""
嵌入模型配置 API
独立于 LLM 配置专门用于 RAG 系统的嵌入模型
使用 UserConfig.other_config 持久化存储
"""
import json
import uuid
from typing import Any, Optional, List
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.api import deps
from app.models.user import User
from app.models.user_config import UserConfig
from app.core.config import settings
# Router for the embedding-configuration endpoints; the handlers below register
# themselves on it via @router.<method>(...).
router: APIRouter = APIRouter()
# ============ Schemas ============
class EmbeddingProvider(BaseModel):
    """Describes one embedding provider offered to clients (see the /providers endpoint)."""
    id: str  # provider identifier used for lookups, e.g. "openai", "ollama"
    name: str  # display name
    description: str  # human-readable description
    models: List[str]  # model names offered for this provider
    requires_api_key: bool  # whether calls to this provider need an API key
    default_model: str  # model pre-selected for this provider
class EmbeddingConfig(BaseModel):
    """Embedding model configuration.

    When built by get_embedding_config_from_db, ``api_key`` holds a masked
    display value (see mask_api_key), not the usable secret.
    """
    # Provider id: openai, ollama, azure, cohere, huggingface, jina or qwen.
    provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface, jina, qwen")
    # Model name.
    model: str = Field(description="模型名称")
    # API key, if the provider needs one.
    api_key: Optional[str] = Field(default=None, description="API Key (如需要)")
    # Custom API endpoint.
    base_url: Optional[str] = Field(default=None, description="自定义 API 端点")
    # Vector dimension (only some models support choosing it); None = model default.
    dimensions: Optional[int] = Field(default=None, description="向量维度 (某些模型支持)")
    # Batch size for embedding requests.
    batch_size: int = Field(default=100, description="批处理大小")
class EmbeddingConfigResponse(BaseModel):
    """Response body of GET /config."""
    provider: str
    model: str
    api_key: Optional[str] = None  # as filled by get_current_config: a masked key, not the secret
    base_url: Optional[str]
    dimensions: int  # always concrete: the configured value or a per-model lookup
    batch_size: int
class TestEmbeddingRequest(BaseModel):
    """Request body of POST /test.

    NOTE(review): test_embedding only reads ``test_text``. provider, model,
    api_key and base_url are accepted but ignored (the system settings are
    tested instead), yet provider and model are still required fields.
    """
    provider: str  # required, but ignored by test_embedding
    model: str  # required, but ignored by test_embedding
    api_key: Optional[str] = None  # ignored by test_embedding
    base_url: Optional[str] = None  # ignored by test_embedding
    test_text: str = "这是一段测试文本,用于验证嵌入模型是否正常工作。"  # text sent to the embedding service
class TestEmbeddingResponse(BaseModel):
    """Response body of POST /test; on failure only ``success`` and ``message`` are set."""
    success: bool
    message: str  # human-readable result or error text
    dimensions: Optional[int] = None  # length of the returned embedding vector
    sample_embedding: Optional[List[float]] = None  # first 5 components of the vector
    latency_ms: Optional[int] = None  # elapsed time of the test call, in milliseconds
# ============ Provider catalogue ============
# Static list of embedding providers and the models offered for each. In this
# module it feeds the /providers and /models/{provider} endpoints; the provider
# actually used for embedding is taken from settings.EMBEDDING_PROVIDER.
EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [
    # OpenAI or any OpenAI-compatible endpoint (custom base URL).
    EmbeddingProvider(
        id="openai",
        name="OpenAI (兼容 DeepSeek/Moonshot/智谱 等)",
        description="OpenAI 官方或兼容 API填写自定义端点可接入其他服务商",
        models=[
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ],
        requires_api_key=True,
        default_model="text-embedding-3-small",
    ),
    # Azure-hosted OpenAI embedding models.
    EmbeddingProvider(
        id="azure",
        name="Azure OpenAI",
        description="Azure 托管的 OpenAI 嵌入模型",
        models=[
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ],
        requires_api_key=True,
        default_model="text-embedding-3-small",
    ),
    # Locally served open-source models; the only provider without an API key.
    EmbeddingProvider(
        id="ollama",
        name="Ollama (本地)",
        description="本地运行的开源嵌入模型 (使用 /api/embed 端点)",
        models=[
            "nomic-embed-text",
            "mxbai-embed-large",
            "all-minilm",
            "snowflake-arctic-embed",
            "bge-m3",
            "qwen3-embedding",
        ],
        requires_api_key=False,
        default_model="nomic-embed-text",
    ),
    # Cohere Embed (v2 API).
    EmbeddingProvider(
        id="cohere",
        name="Cohere",
        description="Cohere Embed v2 API (api.cohere.com/v2)",
        models=[
            "embed-english-v3.0",
            "embed-multilingual-v3.0",
            "embed-english-light-v3.0",
            "embed-multilingual-light-v3.0",
            "embed-v4.0",
        ],
        requires_api_key=True,
        default_model="embed-multilingual-v3.0",
    ),
    # HuggingFace Inference Providers.
    EmbeddingProvider(
        id="huggingface",
        name="HuggingFace",
        description="HuggingFace Inference Providers (router.huggingface.co)",
        models=[
            "sentence-transformers/all-MiniLM-L6-v2",
            "sentence-transformers/all-mpnet-base-v2",
            "BAAI/bge-large-zh-v1.5",
            "BAAI/bge-m3",
        ],
        requires_api_key=True,
        default_model="BAAI/bge-m3",
    ),
    # Jina AI; defaults to the code-oriented model.
    EmbeddingProvider(
        id="jina",
        name="Jina AI",
        description="Jina AI 嵌入模型,代码嵌入效果好",
        models=[
            "jina-embeddings-v2-base-code",
            "jina-embeddings-v2-base-en",
            "jina-embeddings-v2-base-zh",
        ],
        requires_api_key=True,
        default_model="jina-embeddings-v2-base-code",
    ),
    # Alibaba Cloud DashScope Qwen embeddings (OpenAI-compatible interface).
    EmbeddingProvider(
        id="qwen",
        name="Qwen (DashScope)",
        description="阿里云 DashScope Qwen 嵌入模型,兼容 OpenAI embeddings 接口",
        models=[
            "text-embedding-v4",
            "text-embedding-v3",
            "text-embedding-v2",
        ],
        requires_api_key=True,
        default_model="text-embedding-v4",
    ),
]
# ============ Configuration storage ============
# NOTE(review): the handlers in this module read the embedding config only from
# settings (.env). This key is not referenced by any visible code here and looks
# like a leftover from persisting to UserConfig.other_config. Confirm no other
# module imports it before removing.
EMBEDDING_CONFIG_KEY = "embedding_config"
def mask_api_key(key: Optional[str]) -> str:
    """Partially hide an API key for display.

    Empty or missing keys become "". Keys of 8 characters or fewer are hidden
    entirely as "***". Longer keys keep their first 3 and last 4 characters
    around "***".
    """
    if not key:
        return ""
    too_short_to_reveal = len(key) <= 8
    return "***" if too_short_to_reveal else key[:3] + "***" + key[-4:]
async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> EmbeddingConfig:
    """Return the effective embedding configuration (async).

    Despite its name, this does not read the database. The embedding config is
    always the system default from settings (.env), and users cannot override
    it. ``db`` and ``user_id`` are unused and kept only for interface
    compatibility.

    Fallbacks: the API key falls back to ``LLM_API_KEY`` and the base URL to
    ``LLM_BASE_URL`` when the embedding-specific settings are empty.

    Returns:
        EmbeddingConfig whose ``api_key`` is MASKED (via mask_api_key). It is
        suitable for display only and must not be used to authenticate.
    """
    import logging

    # Debug level, not stdout: this can run on every config request.
    logging.getLogger(__name__).debug("[EmbeddingConfig] 返回系统默认嵌入配置(来自 .env")
    configured_dim = settings.EMBEDDING_DIMENSION
    return EmbeddingConfig(
        provider=settings.EMBEDDING_PROVIDER,
        model=settings.EMBEDDING_MODEL,
        api_key=mask_api_key(settings.EMBEDDING_API_KEY or settings.LLM_API_KEY),
        base_url=settings.EMBEDDING_BASE_URL or settings.LLM_BASE_URL,
        # An unset or non-positive dimension means "use the model's default";
        # guarding on truthiness avoids a TypeError when the setting is None.
        dimensions=configured_dim if configured_dim and configured_dim > 0 else None,
        batch_size=100,
    )
# ============ API Endpoints ============
@router.get("/providers", response_model=List[EmbeddingProvider])
async def list_embedding_providers(
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """List every embedding provider clients can choose from.

    Requires an authenticated user; the returned catalogue is the static
    EMBEDDING_PROVIDERS list.
    """
    catalogue = EMBEDDING_PROVIDERS
    return catalogue
@router.get("/config", response_model=EmbeddingConfigResponse)
async def get_current_config(
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Return the active embedding configuration.

    The values come from get_embedding_config_from_db, including its
    (masked) API key. When no positive dimension is configured, the
    dimension is looked up from the provider/model via _get_model_dimensions.
    """
    cfg = await get_embedding_config_from_db(db, current_user.id)

    configured = cfg.dimensions
    resolved_dimensions = (
        configured
        if configured and configured > 0
        else _get_model_dimensions(cfg.provider, cfg.model)
    )

    return EmbeddingConfigResponse(
        provider=cfg.provider,
        model=cfg.model,
        api_key=cfg.api_key,
        base_url=cfg.base_url,
        dimensions=resolved_dimensions,
        batch_size=cfg.batch_size,
    )
@router.put("/config")
async def update_config(
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """No-op handler for embedding-config updates.

    Updating is disabled because the configuration is read only from .env.
    The handler takes no request body, changes nothing, and returns a normal
    (non-error) response echoing the locked provider and model.
    """
    locked_notice = "嵌入模型配置已锁定,请在 .env 文件中进行修改"
    return {
        "message": locked_notice,
        "provider": settings.EMBEDDING_PROVIDER,
        "model": settings.EMBEDDING_MODEL,
    }
@router.post("/test", response_model=TestEmbeddingResponse)
async def test_embedding(
    request: TestEmbeddingRequest,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Run one embedding call against the current system embedding config.

    Only ``request.test_text`` is used; the provider/model/api_key/base_url
    fields of the request are ignored. Provider, model and base URL come from
    settings. The API key is EMBEDDING_API_KEY, falling back to LLM_API_KEY.

    Any exception raised while building the service, embedding, or building
    the success response is logged with its traceback and reported as
    ``success=False`` with the error text in ``message``.
    """
    import logging
    import time

    from app.services.rag.embeddings import EmbeddingService

    try:
        # perf_counter is monotonic, so the latency is immune to wall-clock jumps.
        start_time = time.perf_counter()
        # Always test the real system configuration.
        api_key = getattr(settings, "EMBEDDING_API_KEY", None) or settings.LLM_API_KEY
        # NOTE(review): unlike get_embedding_config_from_db (shown by GET /config),
        # base_url here does not fall back to LLM_BASE_URL. Confirm which one
        # EmbeddingService is meant to receive.
        service = EmbeddingService(
            provider=settings.EMBEDDING_PROVIDER,
            model=settings.EMBEDDING_MODEL,
            api_key=api_key,
            base_url=settings.EMBEDDING_BASE_URL,
            cache_enabled=False,
        )
        embedding = await service.embed(request.test_text)
        latency_ms = int((time.perf_counter() - start_time) * 1000)
        dimensions = len(embedding)
        return TestEmbeddingResponse(
            success=True,
            message=f"嵌入成功! 提供商: {settings.EMBEDDING_PROVIDER}, 维度: {dimensions}",
            dimensions=dimensions,
            sample_embedding=embedding[:5],  # first 5 components only
            latency_ms=latency_ms,
        )
    except Exception as e:
        # API boundary: report the failure to the client, but keep the traceback in the logs.
        logging.getLogger(__name__).exception("❌ 嵌入测试失败: %s", e)
        return TestEmbeddingResponse(
            success=False,
            message=f"嵌入失败: {str(e)}",
        )
@router.get("/models/{provider}")
async def get_provider_models(
    provider: str,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Return the model catalogue entry for a single provider.

    Raises:
        HTTPException: 404 when ``provider`` matches no id in EMBEDDING_PROVIDERS.
    """
    for candidate in EMBEDDING_PROVIDERS:
        if candidate.id == provider:
            break
    else:
        raise HTTPException(status_code=404, detail=f"提供商不存在: {provider}")

    return dict(
        provider=provider,
        models=candidate.models,
        default_model=candidate.default_model,
        requires_api_key=candidate.requires_api_key,
    )
def _get_model_dimensions(provider: str, model: str) -> int:
"""获取模型维度"""
dimensions_map = {
# OpenAI
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
# Ollama
"nomic-embed-text": 768,
"mxbai-embed-large": 1024,
"all-minilm": 384,
"snowflake-arctic-embed": 1024,
# Cohere
"embed-english-v3.0": 1024,
"embed-multilingual-v3.0": 1024,
"embed-english-light-v3.0": 384,
"embed-multilingual-light-v3.0": 384,
# HuggingFace
"sentence-transformers/all-MiniLM-L6-v2": 384,
"sentence-transformers/all-mpnet-base-v2": 768,
"BAAI/bge-large-zh-v1.5": 1024,
"BAAI/bge-m3": 1024,
# Jina
"jina-embeddings-v2-base-code": 768,
"jina-embeddings-v2-base-en": 768,
"jina-embeddings-v2-base-zh": 768,
# Qwen (DashScope)
"text-embedding-v4": 1024, # 支持维度: 2048, 1536, 1024(默认), 768, 512, 256, 128, 64
"text-embedding-v3": 1024, # 支持维度: 1024(默认), 768, 512, 256, 128, 64
"text-embedding-v2": 1536, # 支持维度: 1536
}
return dimensions_map.get(model, 768)