219 lines
7.4 KiB
Python
219 lines
7.4 KiB
Python
"""
|
|
用户配置API端点
|
|
"""
|
|
|
|
from typing import Any, Optional
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from pydantic import BaseModel
|
|
import json
|
|
|
|
from app.api import deps
|
|
from app.db.session import get_db
|
|
from app.models.user_config import UserConfig
|
|
from app.models.user import User
|
|
from app.core.config import settings
|
|
|
|
# Router that collects the user-configuration endpoints declared in this module.
router: APIRouter = APIRouter()
|
|
|
|
|
|
class LLMConfigSchema(BaseModel):
    """LLM-related settings a user may override.

    Every field is optional and defaults to ``None``, meaning "not supplied".
    Field names are used as the JSON keys exchanged with the client, so their
    spelling and order must stay stable.
    """

    # ---- Generic LLM connection settings ----
    llmProvider: Optional[str] = None
    llmApiKey: Optional[str] = None
    llmModel: Optional[str] = None
    llmBaseUrl: Optional[str] = None
    # The system default for this field is expressed in milliseconds.
    llmTimeout: Optional[int] = None
    llmTemperature: Optional[float] = None
    llmMaxTokens: Optional[int] = None
    # Free-form string; its format is not validated by this schema.
    llmCustomHeaders: Optional[str] = None

    # ---- Platform-specific configuration ----
    geminiApiKey: Optional[str] = None
    openaiApiKey: Optional[str] = None
    claudeApiKey: Optional[str] = None
    qwenApiKey: Optional[str] = None
    deepseekApiKey: Optional[str] = None
    zhipuApiKey: Optional[str] = None
    moonshotApiKey: Optional[str] = None
    baiduApiKey: Optional[str] = None
    minimaxApiKey: Optional[str] = None
    doubaoApiKey: Optional[str] = None
    ollamaBaseUrl: Optional[str] = None
|
|
|
|
|
|
class OtherConfigSchema(BaseModel):
    """Non-LLM user settings.

    Covers repository access tokens, analysis limits, LLM call pacing and the
    output language. Every field is optional; ``None`` means "not supplied".
    """

    # Repository host credentials.
    githubToken: Optional[str] = None
    gitlabToken: Optional[str] = None
    # Analysis limits / pacing.
    maxAnalyzeFiles: Optional[int] = None
    llmConcurrency: Optional[int] = None
    llmGapMs: Optional[int] = None  # name suggests a gap in ms between LLM calls — confirm
    # Language used for generated output.
    outputLanguage: Optional[str] = None
|
|
|
|
|
|
class UserConfigRequest(BaseModel):
    """Request body for updating a user's configuration.

    Either section may be omitted (left as ``None``).
    """

    llmConfig: Optional[LLMConfigSchema] = None
    otherConfig: Optional[OtherConfigSchema] = None
|
|
|
|
|
|
class UserConfigResponse(BaseModel):
    """User configuration as returned by the ``/me`` endpoints.

    Timestamps are carried as strings; ``updated_at`` is optional and
    defaults to ``None``.
    """

    id: str
    user_id: str
    llmConfig: dict
    otherConfig: dict
    created_at: str
    updated_at: Optional[str] = None

    class Config:
        # Permit building the model from object attributes (e.g. ORM rows).
        from_attributes = True
|
|
|
|
|
|
def get_default_config() -> dict:
    """Build the system-wide default configuration from ``settings``.

    Returns:
        A dict with two sections, ``"llmConfig"`` and ``"otherConfig"``.
        Falsy string settings are normalised to ``""``, ``llmApiKey`` and
        ``llmCustomHeaders`` are always ``""``, and ``ollamaBaseUrl`` falls
        back to ``"http://localhost:11434/v1"``.
    """
    # (response field, settings attribute) for every per-platform API key,
    # listed in the order the keys must appear in the result.
    platform_key_fields = (
        ("geminiApiKey", "GEMINI_API_KEY"),
        ("openaiApiKey", "OPENAI_API_KEY"),
        ("claudeApiKey", "CLAUDE_API_KEY"),
        ("qwenApiKey", "QWEN_API_KEY"),
        ("deepseekApiKey", "DEEPSEEK_API_KEY"),
        ("zhipuApiKey", "ZHIPU_API_KEY"),
        ("moonshotApiKey", "MOONSHOT_API_KEY"),
        ("baiduApiKey", "BAIDU_API_KEY"),
        ("minimaxApiKey", "MINIMAX_API_KEY"),
        ("doubaoApiKey", "DOUBAO_API_KEY"),
    )

    llm_section = {
        "llmProvider": settings.LLM_PROVIDER,
        "llmApiKey": "",
        "llmModel": settings.LLM_MODEL or "",
        "llmBaseUrl": settings.LLM_BASE_URL or "",
        # LLM_TIMEOUT is scaled by 1000 so the client receives milliseconds.
        "llmTimeout": settings.LLM_TIMEOUT * 1000,
        "llmTemperature": settings.LLM_TEMPERATURE,
        "llmMaxTokens": settings.LLM_MAX_TOKENS,
        "llmCustomHeaders": "",
    }
    for field_name, setting_name in platform_key_fields:
        llm_section[field_name] = getattr(settings, setting_name) or ""
    llm_section["ollamaBaseUrl"] = (
        settings.OLLAMA_BASE_URL or "http://localhost:11434/v1"
    )

    other_section = {
        "githubToken": settings.GITHUB_TOKEN or "",
        "gitlabToken": settings.GITLAB_TOKEN or "",
        "maxAnalyzeFiles": settings.MAX_ANALYZE_FILES,
        "llmConcurrency": settings.LLM_CONCURRENCY,
        "llmGapMs": settings.LLM_GAP_MS,
        "outputLanguage": settings.OUTPUT_LANGUAGE,
    }

    return {"llmConfig": llm_section, "otherConfig": other_section}
|
|
|
|
|
|
# Suffixes that mark credential fields in the default configuration, e.g.
# "llmApiKey", "geminiApiKey", "githubToken". Note that "llmMaxTokens" ends in
# "Tokens", not "Token", so it is deliberately not matched.
_SECRET_FIELD_SUFFIXES = ("ApiKey", "Token")


@router.get("/defaults")
async def get_default_config_endpoint() -> Any:
    """Return the system default configuration (no authentication required).

    Because this route is public, every credential field (names ending in
    ``ApiKey`` or ``Token``) is replaced with ``""`` before returning.
    Otherwise the server's own provider API keys and GitHub/GitLab tokens
    would be readable by any anonymous caller. Section names, keys and all
    non-secret values are returned unchanged. Blanking matches how
    ``llmApiKey`` is already defaulted to ``""``.
    """
    defaults = get_default_config()
    return {
        section: {
            field: "" if field.endswith(_SECRET_FIELD_SUFFIXES) else value
            for field, value in values.items()
        }
        for section, values in defaults.items()
    }
|
|
|
|
|
|
@router.get("/me", response_model=UserConfigResponse)
async def get_my_config(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Return the caller's configuration layered over the system defaults.

    Stored user values win key-by-key over the defaults. When the user has no
    stored record, the defaults are returned with empty ``id`` and
    ``created_at``.
    """
    query = select(UserConfig).where(UserConfig.user_id == current_user.id)
    stored = (await db.execute(query)).scalar_one_or_none()

    defaults = get_default_config()

    # No saved record yet: answer with the pure system defaults.
    if not stored:
        return UserConfigResponse(
            id="",
            user_id=current_user.id,
            llmConfig=defaults["llmConfig"],
            otherConfig=defaults["otherConfig"],
            created_at="",
        )

    # Stored sections are JSON text; an empty/NULL column counts as no overrides.
    llm_overrides, other_overrides = (
        json.loads(raw) if raw else {}
        for raw in (stored.llm_config, stored.other_config)
    )

    combined_llm = {**defaults["llmConfig"], **llm_overrides}
    combined_other = {**defaults["otherConfig"], **other_overrides}

    created = stored.created_at.isoformat() if stored.created_at else ""
    updated = stored.updated_at.isoformat() if stored.updated_at else None

    return UserConfigResponse(
        id=stored.id,
        user_id=stored.user_id,
        llmConfig=combined_llm,
        otherConfig=combined_other,
        created_at=created,
        updated_at=updated,
    )
|
|
|
|
|
|
@router.put("/me", response_model=UserConfigResponse)
async def update_my_config(
    config_in: UserConfigRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Create or partially update the caller's stored configuration.

    Only fields sent with a non-``None`` value are written. For an existing
    record each supplied section is merged over the JSON already stored for
    it, and omitted sections are left untouched. The response echoes the
    stored values; system defaults are not merged in here.
    """
    lookup = await db.execute(
        select(UserConfig).where(UserConfig.user_id == current_user.id)
    )
    record = lookup.scalar_one_or_none()

    if record:
        # Existing record: merge each supplied section into its JSON column.
        section_columns = (
            (config_in.llmConfig, "llm_config"),
            (config_in.otherConfig, "other_config"),
        )
        for section, column in section_columns:
            if not section:
                continue
            previous = getattr(record, column)
            merged = json.loads(previous) if previous else {}
            merged.update(section.dict(exclude_none=True))
            setattr(record, column, json.dumps(merged))
    else:
        # First save for this user: store the supplied sections (or {}).
        llm_payload = (
            config_in.llmConfig.dict(exclude_none=True) if config_in.llmConfig else {}
        )
        other_payload = (
            config_in.otherConfig.dict(exclude_none=True) if config_in.otherConfig else {}
        )
        record = UserConfig(
            user_id=current_user.id,
            llm_config=json.dumps(llm_payload),
            other_config=json.dumps(other_payload),
        )
        db.add(record)

    await db.commit()
    await db.refresh(record)

    saved_llm = json.loads(record.llm_config) if record.llm_config else {}
    saved_other = json.loads(record.other_config) if record.other_config else {}

    return UserConfigResponse(
        id=record.id,
        user_id=record.user_id,
        llmConfig=saved_llm,
        otherConfig=saved_other,
        created_at=record.created_at.isoformat() if record.created_at else "",
        updated_at=record.updated_at.isoformat() if record.updated_at else None,
    )
|
|
|
|
|
|
@router.delete("/me")
async def delete_my_config(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Remove the caller's stored configuration so the defaults apply again.

    Returns the same confirmation message whether or not a record existed.
    """
    existing = (
        await db.execute(
            select(UserConfig).where(UserConfig.user_id == current_user.id)
        )
    ).scalar_one_or_none()

    # Nothing stored: nothing to delete, report success anyway.
    if not existing:
        return {"message": "配置已删除"}

    await db.delete(existing)
    await db.commit()
    return {"message": "配置已删除"}
|
|
|