CodeReview/backend/app/api/v1/endpoints/config.py

219 lines
7.4 KiB
Python

"""
用户配置API端点
"""
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from pydantic import BaseModel
import json
from app.api import deps
from app.db.session import get_db
from app.models.user_config import UserConfig
from app.models.user import User
from app.core.config import settings
router = APIRouter()
class LLMConfigSchema(BaseModel):
    """LLM configuration schema.

    Every field is optional and defaults to ``None``. The update endpoint in
    this module serialises the model with ``exclude_none=True``, so fields
    left as ``None`` are not persisted.
    """
    # Generic LLM settings.
    llmProvider: Optional[str] = None
    llmApiKey: Optional[str] = None
    llmModel: Optional[str] = None
    llmBaseUrl: Optional[str] = None
    # NOTE(review): presumably milliseconds — the system default for this key
    # is built as settings.LLM_TIMEOUT * 1000; confirm against the client.
    llmTimeout: Optional[int] = None
    llmTemperature: Optional[float] = None
    llmMaxTokens: Optional[int] = None
    # NOTE(review): stored as a plain string; the expected format (e.g. JSON)
    # is not visible here — confirm against the consumer.
    llmCustomHeaders: Optional[str] = None
    # Platform-specific settings.
    geminiApiKey: Optional[str] = None
    openaiApiKey: Optional[str] = None
    claudeApiKey: Optional[str] = None
    qwenApiKey: Optional[str] = None
    deepseekApiKey: Optional[str] = None
    zhipuApiKey: Optional[str] = None
    moonshotApiKey: Optional[str] = None
    baiduApiKey: Optional[str] = None
    minimaxApiKey: Optional[str] = None
    doubaoApiKey: Optional[str] = None
    ollamaBaseUrl: Optional[str] = None
class OtherConfigSchema(BaseModel):
    """Schema for the non-LLM ("other") configuration section.

    Every field is optional and defaults to ``None``. The update endpoint in
    this module serialises the model with ``exclude_none=True``, so fields
    left as ``None`` are not persisted.
    """
    # Access tokens for the code-hosting platforms.
    githubToken: Optional[str] = None
    gitlabToken: Optional[str] = None
    # Analysis limits and LLM request pacing.
    maxAnalyzeFiles: Optional[int] = None
    llmConcurrency: Optional[int] = None
    llmGapMs: Optional[int] = None
    outputLanguage: Optional[str] = None
class UserConfigRequest(BaseModel):
    """Request body for updating a user's configuration.

    Both sections are optional; a section that is omitted (``None``) is left
    untouched when an existing configuration row is updated.
    """
    llmConfig: Optional[LLMConfigSchema] = None
    otherConfig: Optional[OtherConfigSchema] = None
class UserConfigResponse(BaseModel):
    """Response body describing a user's configuration."""
    # Empty string when the user has no stored configuration row.
    id: str
    user_id: str
    # Configuration sections as plain dicts (decoded from the stored JSON).
    llmConfig: dict
    otherConfig: dict
    # ISO-8601 timestamp, or an empty string when no timestamp is available.
    created_at: str
    # ISO-8601 timestamp, or None when the row has no update timestamp.
    updated_at: Optional[str] = None

    class Config:
        from_attributes = True
def get_default_config() -> dict:
    """Build the system default configuration from the application settings.

    Returns:
        A dict with two sections, ``"llmConfig"`` and ``"otherConfig"``, whose
        keys mirror the fields of the request schemas. Unset string settings
        are normalised to ``""`` (or to the local Ollama URL for
        ``ollamaBaseUrl``).
    """
    llm_defaults = {
        "llmProvider": settings.LLM_PROVIDER,
        "llmApiKey": "",
        "llmModel": settings.LLM_MODEL or "",
        "llmBaseUrl": settings.LLM_BASE_URL or "",
        # Convert to milliseconds.
        "llmTimeout": settings.LLM_TIMEOUT * 1000,
        "llmTemperature": settings.LLM_TEMPERATURE,
        "llmMaxTokens": settings.LLM_MAX_TOKENS,
        "llmCustomHeaders": "",
        "geminiApiKey": settings.GEMINI_API_KEY or "",
        "openaiApiKey": settings.OPENAI_API_KEY or "",
        "claudeApiKey": settings.CLAUDE_API_KEY or "",
        "qwenApiKey": settings.QWEN_API_KEY or "",
        "deepseekApiKey": settings.DEEPSEEK_API_KEY or "",
        "zhipuApiKey": settings.ZHIPU_API_KEY or "",
        "moonshotApiKey": settings.MOONSHOT_API_KEY or "",
        "baiduApiKey": settings.BAIDU_API_KEY or "",
        "minimaxApiKey": settings.MINIMAX_API_KEY or "",
        "doubaoApiKey": settings.DOUBAO_API_KEY or "",
        "ollamaBaseUrl": settings.OLLAMA_BASE_URL or "http://localhost:11434/v1",
    }

    other_defaults = {
        "githubToken": settings.GITHUB_TOKEN or "",
        "gitlabToken": settings.GITLAB_TOKEN or "",
        "maxAnalyzeFiles": settings.MAX_ANALYZE_FILES,
        "llmConcurrency": settings.LLM_CONCURRENCY,
        "llmGapMs": settings.LLM_GAP_MS,
        "outputLanguage": settings.OUTPUT_LANGUAGE,
    }

    return {"llmConfig": llm_defaults, "otherConfig": other_defaults}
@router.get("/defaults")
async def get_default_config_endpoint() -> Any:
    """Return the system default configuration (no authentication required).

    Because this route is public, credential-bearing fields (any key ending
    in ``ApiKey`` or ``Token``) are blanked before the defaults are returned,
    so server-side secrets taken from the application settings are never
    disclosed to anonymous callers. The response keeps the same sections and
    keys as ``get_default_config()``; only the secret values are replaced
    with ``""``.

    Returns:
        A dict with ``"llmConfig"`` and ``"otherConfig"`` sections.
    """
    # Suffixes identifying secret fields. "llmMaxTokens" is intentionally not
    # matched: it ends in "Tokens", not "Token".
    secret_suffixes = ("ApiKey", "Token")

    defaults = get_default_config()
    return {
        section: {
            key: "" if key.endswith(secret_suffixes) else value
            for key, value in values.items()
        }
        for section, values in defaults.items()
    }
@router.get("/me", response_model=UserConfigResponse)
async def get_my_config(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Return the current user's effective configuration.

    The stored per-user values are layered on top of the system defaults, so
    a user-supplied key always wins over the default for the same key. When
    the user has no stored configuration, the defaults are returned with an
    empty ``id`` and ``created_at``.
    """
    query = select(UserConfig).where(UserConfig.user_id == current_user.id)
    stored = (await db.execute(query)).scalar_one_or_none()

    # System defaults form the base layer of the response.
    defaults = get_default_config()
    llm_defaults = defaults["llmConfig"]
    other_defaults = defaults["otherConfig"]

    if not stored:
        # Nothing saved for this user: hand back the defaults unchanged.
        return UserConfigResponse(
            id="",
            user_id=current_user.id,
            llmConfig=llm_defaults,
            otherConfig=other_defaults,
            created_at="",
        )

    # Decode the JSON-encoded overrides; an empty column means "no overrides".
    llm_overrides = json.loads(stored.llm_config) if stored.llm_config else {}
    other_overrides = json.loads(stored.other_config) if stored.other_config else {}

    created = stored.created_at.isoformat() if stored.created_at else ""
    updated = stored.updated_at.isoformat() if stored.updated_at else None

    return UserConfigResponse(
        id=stored.id,
        user_id=stored.user_id,
        # Later mappings win, so user overrides take precedence over defaults.
        llmConfig={**llm_defaults, **llm_overrides},
        otherConfig={**other_defaults, **other_overrides},
        created_at=created,
        updated_at=updated,
    )
@router.put("/me", response_model=UserConfigResponse)
async def update_my_config(
    config_in: UserConfigRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Create or update the current user's stored configuration.

    Fields that are ``None`` in the request are dropped before saving. For an
    existing row, each supplied section is merged key-by-key into the stored
    JSON, and an omitted section is left untouched. For a new row, an omitted
    section is stored as an empty JSON object.

    The response contains only the values stored for the user; unlike
    ``GET /me`` it is not merged with the system defaults.
    """
    lookup = select(UserConfig).where(UserConfig.user_id == current_user.id)
    record = (await db.execute(lookup)).scalar_one_or_none()

    llm_section = config_in.llmConfig
    other_section = config_in.otherConfig

    if record:
        # Existing row: merge each supplied section into what is stored.
        if llm_section:
            llm_state = json.loads(record.llm_config) if record.llm_config else {}
            llm_state.update(llm_section.dict(exclude_none=True))
            record.llm_config = json.dumps(llm_state)
        if other_section:
            other_state = json.loads(record.other_config) if record.other_config else {}
            other_state.update(other_section.dict(exclude_none=True))
            record.other_config = json.dumps(other_state)
    else:
        # First save for this user: create the row from the request.
        initial_llm = llm_section.dict(exclude_none=True) if llm_section else {}
        initial_other = other_section.dict(exclude_none=True) if other_section else {}
        record = UserConfig(
            user_id=current_user.id,
            llm_config=json.dumps(initial_llm),
            other_config=json.dumps(initial_other),
        )
        db.add(record)

    await db.commit()
    # Reload the row so database-populated columns are available below.
    await db.refresh(record)

    saved_llm = json.loads(record.llm_config) if record.llm_config else {}
    saved_other = json.loads(record.other_config) if record.other_config else {}
    created = record.created_at.isoformat() if record.created_at else ""
    updated = record.updated_at.isoformat() if record.updated_at else None

    return UserConfigResponse(
        id=record.id,
        user_id=record.user_id,
        llmConfig=saved_llm,
        otherConfig=saved_other,
        created_at=created,
        updated_at=updated,
    )
@router.delete("/me")
async def delete_my_config(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Delete the current user's stored configuration (revert to defaults).

    The same confirmation message is returned whether or not a stored
    configuration existed.
    """
    lookup = select(UserConfig).where(UserConfig.user_id == current_user.id)
    stored = (await db.execute(lookup)).scalar_one_or_none()

    if stored:
        # Only touch the database when there is a row to remove.
        await db.delete(stored)
        await db.commit()

    return {"message": "配置已删除"}