CodeReview/backend/app/api/v1/endpoints/config.py

425 lines
15 KiB
Python

"""
用户配置API端点
"""
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from pydantic import BaseModel
import json
from app.api import deps
from app.db.session import get_db
from app.models.user_config import UserConfig
from app.models.user import User
from app.core.config import settings
from app.core.encryption import encrypt_sensitive_data, decrypt_sensitive_data
# Router on which every endpoint in this module is registered via the
# ``@router.<method>`` decorators below.
router = APIRouter()
# Keys of the LLM config section whose values are passed through
# encrypt_config / decrypt_config by the endpoints in this module.
SENSITIVE_LLM_FIELDS = [
    "llmApiKey",
    "geminiApiKey",
    "openaiApiKey",
    "claudeApiKey",
    "qwenApiKey",
    "deepseekApiKey",
    "zhipuApiKey",
    "moonshotApiKey",
    "baiduApiKey",
    "minimaxApiKey",
    "doubaoApiKey",
]

# Keys of the "other" config section that receive the same treatment.
SENSITIVE_OTHER_FIELDS = [
    "githubToken",
    "gitlabToken",
]
def encrypt_config(config: dict, sensitive_fields: list) -> dict:
    """Return a shallow copy of ``config`` with its sensitive fields encrypted.

    Only keys named in ``sensitive_fields`` that are present with a truthy
    value are passed through ``encrypt_sensitive_data``; missing keys and
    empty values are left alone. The input mapping itself is not modified.
    """
    result = config.copy()
    for name in sensitive_fields:
        plaintext = result.get(name)
        if plaintext:
            result[name] = encrypt_sensitive_data(plaintext)
    return result
def decrypt_config(config: dict, sensitive_fields: list) -> dict:
    """Return a shallow copy of ``config`` with its sensitive fields decrypted.

    Keys named in ``sensitive_fields`` are run through
    ``decrypt_sensitive_data`` when they are present and truthy; everything
    else is copied unchanged. The input mapping itself is not modified.
    """
    result = config.copy()
    for name in sensitive_fields:
        if not result.get(name):
            # Absent or empty: nothing to decrypt.
            continue
        result[name] = decrypt_sensitive_data(result[name])
    return result
class LLMConfigSchema(BaseModel):
    """LLM section of a user configuration request.

    Every field is optional and defaults to ``None``.
    """

    # Settings for the selected provider
    llmProvider: Optional[str] = None
    llmApiKey: Optional[str] = None
    llmModel: Optional[str] = None
    llmBaseUrl: Optional[str] = None
    llmTimeout: Optional[int] = None
    llmTemperature: Optional[float] = None
    llmMaxTokens: Optional[int] = None
    llmCustomHeaders: Optional[str] = None

    # Per-platform settings
    geminiApiKey: Optional[str] = None
    openaiApiKey: Optional[str] = None
    claudeApiKey: Optional[str] = None
    qwenApiKey: Optional[str] = None
    deepseekApiKey: Optional[str] = None
    zhipuApiKey: Optional[str] = None
    moonshotApiKey: Optional[str] = None
    baiduApiKey: Optional[str] = None
    minimaxApiKey: Optional[str] = None
    doubaoApiKey: Optional[str] = None
    ollamaBaseUrl: Optional[str] = None
class OtherConfigSchema(BaseModel):
    """Non-LLM section of a user configuration request.

    Every field is optional and defaults to ``None``.
    """

    # Repository access tokens (treated as sensitive by this module)
    githubToken: Optional[str] = None
    gitlabToken: Optional[str] = None

    # Remaining tunables
    maxAnalyzeFiles: Optional[int] = None
    llmConcurrency: Optional[int] = None
    llmGapMs: Optional[int] = None
    outputLanguage: Optional[str] = None
class UserConfigRequest(BaseModel):
    """Request body for updating a user configuration.

    Both sections are optional; each defaults to ``None`` when omitted.
    """

    llmConfig: Optional[LLMConfigSchema] = None
    otherConfig: Optional[OtherConfigSchema] = None
class UserConfigResponse(BaseModel):
    """Response body describing a user's configuration."""

    id: str
    user_id: str
    # Both sections are returned as plain dicts.
    llmConfig: dict
    otherConfig: dict
    # Timestamps are carried as strings; updated_at may be absent.
    created_at: str
    updated_at: Optional[str] = None

    class Config:
        from_attributes = True
def get_default_config() -> dict:
    """Build the system default configuration from the application settings.

    Returns a freshly built dict with two sections, ``llmConfig`` and
    ``otherConfig``. Optional string settings fall back to ``""`` (or, for
    the Ollama base URL, to ``http://localhost:11434/v1``).
    """
    # (response key, settings attribute) for each per-provider API key,
    # in the order the keys appear in the response.
    provider_key_settings = (
        ("geminiApiKey", "GEMINI_API_KEY"),
        ("openaiApiKey", "OPENAI_API_KEY"),
        ("claudeApiKey", "CLAUDE_API_KEY"),
        ("qwenApiKey", "QWEN_API_KEY"),
        ("deepseekApiKey", "DEEPSEEK_API_KEY"),
        ("zhipuApiKey", "ZHIPU_API_KEY"),
        ("moonshotApiKey", "MOONSHOT_API_KEY"),
        ("baiduApiKey", "BAIDU_API_KEY"),
        ("minimaxApiKey", "MINIMAX_API_KEY"),
        ("doubaoApiKey", "DOUBAO_API_KEY"),
    )

    llm_defaults = {
        "llmProvider": settings.LLM_PROVIDER,
        "llmApiKey": "",
        "llmModel": settings.LLM_MODEL or "",
        "llmBaseUrl": settings.LLM_BASE_URL or "",
        # Scaled by 1000: the response carries milliseconds
        # (the setting is presumably in seconds).
        "llmTimeout": settings.LLM_TIMEOUT * 1000,
        "llmTemperature": settings.LLM_TEMPERATURE,
        "llmMaxTokens": settings.LLM_MAX_TOKENS,
        "llmCustomHeaders": "",
    }
    for response_key, setting_name in provider_key_settings:
        llm_defaults[response_key] = getattr(settings, setting_name) or ""
    llm_defaults["ollamaBaseUrl"] = (
        settings.OLLAMA_BASE_URL or "http://localhost:11434/v1"
    )

    other_defaults = {
        "githubToken": settings.GITHUB_TOKEN or "",
        "gitlabToken": settings.GITLAB_TOKEN or "",
        "maxAnalyzeFiles": settings.MAX_ANALYZE_FILES,
        "llmConcurrency": settings.LLM_CONCURRENCY,
        "llmGapMs": settings.LLM_GAP_MS,
        "outputLanguage": settings.OUTPUT_LANGUAGE,
    }

    return {"llmConfig": llm_defaults, "otherConfig": other_defaults}
@router.get("/defaults")
async def get_default_config_endpoint() -> Any:
    """Return the system default configuration (no authentication required).

    Because this route has no auth dependency, every field listed in
    ``SENSITIVE_LLM_FIELDS`` / ``SENSITIVE_OTHER_FIELDS`` is blanked before
    the response is sent. Without this, the server's own provider API keys
    and repository tokens from settings would be readable by any anonymous
    caller. The response shape (sections and keys) is unchanged.
    """
    # get_default_config() builds a new dict on every call, so editing it in
    # place here cannot affect other callers.
    defaults = get_default_config()
    for section, secret_fields in (
        ("llmConfig", SENSITIVE_LLM_FIELDS),
        ("otherConfig", SENSITIVE_OTHER_FIELDS),
    ):
        section_values = defaults[section]
        for field in secret_fields:
            if field in section_values:
                section_values[field] = ""
    return defaults
@router.get("/me", response_model=UserConfigResponse)
async def get_my_config(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Return the current user's configuration merged over the system defaults.

    When the user has no stored row, the defaults are returned with an empty
    ``id`` and ``created_at``. Otherwise both stored JSON sections are
    decoded, their sensitive fields decrypted, and the result is laid over
    the defaults so that the user's values take precedence.
    """
    lookup = await db.execute(
        select(UserConfig).where(UserConfig.user_id == current_user.id)
    )
    stored = lookup.scalar_one_or_none()
    defaults = get_default_config()

    # Nothing saved for this user: answer with the bare defaults.
    if not stored:
        print(f"[Config] 用户 {current_user.id} 没有保存的配置,返回默认配置")
        return UserConfigResponse(
            id="",
            user_id=current_user.id,
            llmConfig=defaults["llmConfig"],
            otherConfig=defaults["otherConfig"],
            created_at="",
        )

    # Decode both stored sections, then decrypt their sensitive fields.
    raw_llm = json.loads(stored.llm_config) if stored.llm_config else {}
    raw_other = json.loads(stored.other_config) if stored.other_config else {}
    llm_overrides = decrypt_config(raw_llm, SENSITIVE_LLM_FIELDS)
    other_overrides = decrypt_config(raw_other, SENSITIVE_OTHER_FIELDS)

    # Debug output; only the last four characters of the key are shown.
    print(f"[Config] 用户 {current_user.id} 的保存配置:")
    print(f" - llmProvider: {llm_overrides.get('llmProvider')}")
    saved_key = llm_overrides.get('llmApiKey')
    key_preview = '***' + saved_key[-4:] if saved_key else '(空)'
    print(f" - llmApiKey: {key_preview}")
    print(f" - llmModel: {llm_overrides.get('llmModel')}")

    created = stored.created_at.isoformat() if stored.created_at else ""
    updated = stored.updated_at.isoformat() if stored.updated_at else None

    # User values are unpacked last so they override the defaults.
    return UserConfigResponse(
        id=stored.id,
        user_id=stored.user_id,
        llmConfig={**defaults["llmConfig"], **llm_overrides},
        otherConfig={**defaults["otherConfig"], **other_overrides},
        created_at=created,
        updated_at=updated,
    )
@router.put("/me", response_model=UserConfigResponse)
async def update_my_config(
    config_in: UserConfigRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Create or update the current user's configuration.

    Fields left as ``None`` in the request are ignored. For a new user a row
    is created from the submitted values; for an existing row each submitted
    section is merged into the stored one (decrypt, overlay, re-encrypt).
    Sensitive fields are stored encrypted. The response mirrors
    ``get_my_config``: decrypted values laid over the system defaults.
    """
    lookup = await db.execute(
        select(UserConfig).where(UserConfig.user_id == current_user.id)
    )
    record = lookup.scalar_one_or_none()

    # Plaintext values supplied by the client, with None fields dropped.
    new_llm = config_in.llmConfig.dict(exclude_none=True) if config_in.llmConfig else {}
    new_other = config_in.otherConfig.dict(exclude_none=True) if config_in.otherConfig else {}

    # Encrypted forms of the submitted values (used when creating a row).
    new_llm_encrypted = encrypt_config(new_llm, SENSITIVE_LLM_FIELDS)
    new_other_encrypted = encrypt_config(new_other, SENSITIVE_OTHER_FIELDS)

    if record:
        # Existing row: only touch the sections present in the request.
        if config_in.llmConfig:
            stored_llm = json.loads(record.llm_config) if record.llm_config else {}
            # Merge on plaintext so the whole section is encrypted exactly once.
            stored_llm = decrypt_config(stored_llm, SENSITIVE_LLM_FIELDS)
            stored_llm.update(new_llm)
            record.llm_config = json.dumps(
                encrypt_config(stored_llm, SENSITIVE_LLM_FIELDS)
            )
        if config_in.otherConfig:
            stored_other = json.loads(record.other_config) if record.other_config else {}
            stored_other = decrypt_config(stored_other, SENSITIVE_OTHER_FIELDS)
            stored_other.update(new_other)
            record.other_config = json.dumps(
                encrypt_config(stored_other, SENSITIVE_OTHER_FIELDS)
            )
    else:
        # First save for this user.
        record = UserConfig(
            user_id=current_user.id,
            llm_config=json.dumps(new_llm_encrypted),
            other_config=json.dumps(new_other_encrypted),
        )
        db.add(record)

    await db.commit()
    await db.refresh(record)

    # Build the response the same way get_my_config does.
    defaults = get_default_config()
    saved_llm = json.loads(record.llm_config) if record.llm_config else {}
    saved_other = json.loads(record.other_config) if record.other_config else {}
    saved_llm = decrypt_config(saved_llm, SENSITIVE_LLM_FIELDS)
    saved_other = decrypt_config(saved_other, SENSITIVE_OTHER_FIELDS)

    created = record.created_at.isoformat() if record.created_at else ""
    updated = record.updated_at.isoformat() if record.updated_at else None

    return UserConfigResponse(
        id=record.id,
        user_id=record.user_id,
        llmConfig={**defaults["llmConfig"], **saved_llm},
        otherConfig={**defaults["otherConfig"], **saved_other},
        created_at=created,
        updated_at=updated,
    )
@router.delete("/me")
async def delete_my_config(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Delete the current user's stored configuration, if any.

    The same confirmation message is returned whether or not a row existed.
    """
    lookup = await db.execute(
        select(UserConfig).where(UserConfig.user_id == current_user.id)
    )
    stored = lookup.scalar_one_or_none()

    if not stored:
        # Nothing to remove.
        return {"message": "配置已删除"}

    await db.delete(stored)
    await db.commit()
    return {"message": "配置已删除"}
class LLMTestRequest(BaseModel):
    """Request body for the LLM connection test."""

    # Required
    provider: str
    apiKey: str

    # Optional; default to None when omitted
    model: Optional[str] = None
    baseUrl: Optional[str] = None
class LLMTestResponse(BaseModel):
    """Result of the LLM connection test."""

    # Always present
    success: bool
    message: str

    # Optional details; default to None
    model: Optional[str] = None
    response: Optional[str] = None
# Ordered (lower-case needles, user-facing message) rules used to turn a raw
# exception text into a friendlier message. The first rule with a matching
# needle wins, so the order below is significant.
_LLM_ERROR_RULES = (
    (
        ("401", "invalid_api_key", "incorrect api key"),
        "API Key 无效或已过期,请检查后重试",
    ),
    (("authentication",), "认证失败,请检查 API Key 是否正确"),
    (("timeout",), "连接超时,请检查网络或 API 地址是否正确"),
    (("connection",), "无法连接到 API 服务,请检查网络或 API 地址"),
)


def _friendly_llm_error(error_msg: str) -> str:
    """Map a raw exception text to the message shown to the user."""
    lowered = error_msg.lower()
    for needles, message in _LLM_ERROR_RULES:
        if any(needle in lowered for needle in needles):
            return message
    # No known pattern: fall back to the raw error text.
    return f"LLM连接测试失败: {error_msg}"


@router.post("/test-llm", response_model=LLMTestResponse)
async def test_llm_connection(
    request: LLMTestRequest,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Check that the submitted LLM settings can complete a tiny request.

    Args:
        request: Provider name, API key and optional model / base URL.
        current_user: Authenticated user (required by the dependency; not
            otherwise used).

    Returns:
        An ``LLMTestResponse``. Failures inside the test itself (unknown
        provider, empty reply, or any exception raised while building the
        adapter or calling it) are reported with ``success=False`` rather
        than raised.
    """
    from app.services.llm.factory import NATIVE_ONLY_PROVIDERS
    from app.services.llm.adapters import LiteLLMAdapter, BaiduAdapter, MinimaxAdapter, DoubaoAdapter
    from app.services.llm.types import LLMConfig, LLMProvider, LLMRequest, LLMMessage, DEFAULT_MODELS

    try:
        # Resolve the provider name (case-insensitive) to its enum member.
        provider_map = {
            'gemini': LLMProvider.GEMINI,
            'openai': LLMProvider.OPENAI,
            'claude': LLMProvider.CLAUDE,
            'qwen': LLMProvider.QWEN,
            'deepseek': LLMProvider.DEEPSEEK,
            'zhipu': LLMProvider.ZHIPU,
            'moonshot': LLMProvider.MOONSHOT,
            'baidu': LLMProvider.BAIDU,
            'minimax': LLMProvider.MINIMAX,
            'doubao': LLMProvider.DOUBAO,
            'ollama': LLMProvider.OLLAMA,
        }
        provider = provider_map.get(request.provider.lower())
        if not provider:
            return LLMTestResponse(
                success=False,
                message=f"不支持的LLM提供商: {request.provider}"
            )

        # Fall back to the provider's default model when none was given.
        model = request.model or DEFAULT_MODELS.get(provider)

        config = LLMConfig(
            provider=provider,
            api_key=request.apiKey,
            model=model,
            base_url=request.baseUrl,
            timeout=30,  # shorter timeout for the test
            max_tokens=50,  # small token budget for the test
        )

        # Instantiate an adapter directly rather than through the factory so
        # the just-submitted settings are used (per the original note, this
        # avoids a cached instance).
        if provider in NATIVE_ONLY_PROVIDERS:
            native_adapter_map = {
                LLMProvider.BAIDU: BaiduAdapter,
                LLMProvider.MINIMAX: MinimaxAdapter,
                LLMProvider.DOUBAO: DoubaoAdapter,
            }
            adapter = native_adapter_map[provider](config)
        else:
            adapter = LiteLLMAdapter(config)

        test_request = LLMRequest(
            messages=[
                LLMMessage(role="user", content="Say 'Hello' in one word.")
            ],
            max_tokens=50,
        )
        response = await adapter.complete(test_request)

        # An empty reply counts as a failed test.
        if not response or not response.content:
            return LLMTestResponse(
                success=False,
                message="LLM 返回空响应,请检查 API Key 和配置"
            )

        return LLMTestResponse(
            success=True,
            message="LLM连接测试成功",
            model=model,
            # Content is known to be non-empty here; cap what is echoed back.
            response=response.content[:100],
        )
    except Exception as e:
        # Deliberately broad: this endpoint reports any failure of the test
        # as a normal response instead of an HTTP error.
        return LLMTestResponse(
            success=False,
            message=_friendly_llm_error(str(e)),
        )
@router.get("/llm-providers")
async def get_llm_providers() -> Any:
    """Return the supported LLM providers.

    The response is ``{"providers": [...]}`` with one entry per provider
    reported by ``LLMFactory.get_supported_providers()``. Each entry carries
    the provider's id, an upper-cased display name, its default model, its
    available models and its default base URL (``""`` when none is
    registered).
    """
    from app.services.llm.factory import LLMFactory
    from app.services.llm.types import DEFAULT_BASE_URLS

    providers = [
        {
            "id": provider.value,
            "name": provider.value.upper(),
            "defaultModel": LLMFactory.get_default_model(provider),
            "models": LLMFactory.get_available_models(provider),
            "defaultBaseUrl": DEFAULT_BASE_URLS.get(provider, ""),
        }
        for provider in LLMFactory.get_supported_providers()
    ]
    return {"providers": providers}