""" 用户配置API端点 """ from typing import Any, Optional from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from pydantic import BaseModel import json from app.api import deps from app.db.session import get_db from app.models.user_config import UserConfig from app.models.user import User from app.core.config import settings from app.core.encryption import encrypt_sensitive_data, decrypt_sensitive_data router = APIRouter() # 需要加密的敏感字段列表 SENSITIVE_LLM_FIELDS = [ 'llmApiKey', 'geminiApiKey', 'openaiApiKey', 'claudeApiKey', 'qwenApiKey', 'deepseekApiKey', 'zhipuApiKey', 'moonshotApiKey', 'baiduApiKey', 'minimaxApiKey', 'doubaoApiKey' ] SENSITIVE_OTHER_FIELDS = ['githubToken', 'gitlabToken'] def encrypt_config(config: dict, sensitive_fields: list) -> dict: """加密配置中的敏感字段""" encrypted = config.copy() for field in sensitive_fields: if field in encrypted and encrypted[field]: encrypted[field] = encrypt_sensitive_data(encrypted[field]) return encrypted def decrypt_config(config: dict, sensitive_fields: list) -> dict: """解密配置中的敏感字段""" decrypted = config.copy() for field in sensitive_fields: if field in decrypted and decrypted[field]: decrypted[field] = decrypt_sensitive_data(decrypted[field]) return decrypted class LLMConfigSchema(BaseModel): """LLM配置Schema""" llmProvider: Optional[str] = None llmApiKey: Optional[str] = None llmModel: Optional[str] = None llmBaseUrl: Optional[str] = None llmTimeout: Optional[int] = None llmTemperature: Optional[float] = None llmMaxTokens: Optional[int] = None llmCustomHeaders: Optional[str] = None # 平台专用配置 geminiApiKey: Optional[str] = None openaiApiKey: Optional[str] = None claudeApiKey: Optional[str] = None qwenApiKey: Optional[str] = None deepseekApiKey: Optional[str] = None zhipuApiKey: Optional[str] = None moonshotApiKey: Optional[str] = None baiduApiKey: Optional[str] = None minimaxApiKey: Optional[str] = None doubaoApiKey: Optional[str] = None ollamaBaseUrl: Optional[str] = None class OtherConfigSchema(BaseModel): """其他配置Schema""" githubToken: Optional[str] = None gitlabToken: Optional[str] = None maxAnalyzeFiles: Optional[int] = None llmConcurrency: Optional[int] = None llmGapMs: Optional[int] = None outputLanguage: Optional[str] = None class UserConfigRequest(BaseModel): """用户配置请求""" llmConfig: Optional[LLMConfigSchema] = None otherConfig: Optional[OtherConfigSchema] = None class UserConfigResponse(BaseModel): """用户配置响应""" id: str user_id: str llmConfig: dict otherConfig: dict created_at: str updated_at: Optional[str] = None class Config: from_attributes = True def get_default_config() -> dict: """获取系统默认配置""" return { "llmConfig": { "llmProvider": settings.LLM_PROVIDER, "llmApiKey": "", "llmModel": settings.LLM_MODEL or "", "llmBaseUrl": settings.LLM_BASE_URL or "", "llmTimeout": settings.LLM_TIMEOUT * 1000, # 转换为毫秒 "llmTemperature": settings.LLM_TEMPERATURE, "llmMaxTokens": settings.LLM_MAX_TOKENS, "llmCustomHeaders": "", "geminiApiKey": settings.GEMINI_API_KEY or "", "openaiApiKey": settings.OPENAI_API_KEY or "", "claudeApiKey": settings.CLAUDE_API_KEY or "", "qwenApiKey": settings.QWEN_API_KEY or "", "deepseekApiKey": settings.DEEPSEEK_API_KEY or "", "zhipuApiKey": settings.ZHIPU_API_KEY or "", "moonshotApiKey": settings.MOONSHOT_API_KEY or "", "baiduApiKey": settings.BAIDU_API_KEY or "", "minimaxApiKey": settings.MINIMAX_API_KEY or "", "doubaoApiKey": settings.DOUBAO_API_KEY or "", "ollamaBaseUrl": settings.OLLAMA_BASE_URL or "http://localhost:11434/v1", }, "otherConfig": { "githubToken": settings.GITHUB_TOKEN or "", "gitlabToken": settings.GITLAB_TOKEN or "", "maxAnalyzeFiles": settings.MAX_ANALYZE_FILES, "llmConcurrency": settings.LLM_CONCURRENCY, "llmGapMs": settings.LLM_GAP_MS, "outputLanguage": settings.OUTPUT_LANGUAGE, } } @router.get("/defaults") async def get_default_config_endpoint() -> Any: """获取系统默认配置(无需认证)""" return get_default_config() @router.get("/me", response_model=UserConfigResponse) async def get_my_config( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: """获取当前用户的配置(合并用户配置和系统默认配置)""" result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) config = result.scalar_one_or_none() # 获取系统默认配置 default_config = get_default_config() if not config: print(f"[Config] 用户 {current_user.id} 没有保存的配置,返回默认配置") # 返回系统默认配置 return UserConfigResponse( id="", user_id=current_user.id, llmConfig=default_config["llmConfig"], otherConfig=default_config["otherConfig"], created_at="", ) # 合并用户配置和默认配置(用户配置优先) user_llm_config = json.loads(config.llm_config) if config.llm_config else {} user_other_config = json.loads(config.other_config) if config.other_config else {} # 解密敏感字段 user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS) user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS) print(f"[Config] 用户 {current_user.id} 的保存配置:") print(f" - llmProvider: {user_llm_config.get('llmProvider')}") print(f" - llmApiKey: {'***' + user_llm_config.get('llmApiKey', '')[-4:] if user_llm_config.get('llmApiKey') else '(空)'}") print(f" - llmModel: {user_llm_config.get('llmModel')}") merged_llm_config = {**default_config["llmConfig"], **user_llm_config} merged_other_config = {**default_config["otherConfig"], **user_other_config} return UserConfigResponse( id=config.id, user_id=config.user_id, llmConfig=merged_llm_config, otherConfig=merged_other_config, created_at=config.created_at.isoformat() if config.created_at else "", updated_at=config.updated_at.isoformat() if config.updated_at else None, ) @router.put("/me", response_model=UserConfigResponse) async def update_my_config( config_in: UserConfigRequest, db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: """更新当前用户的配置""" result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) config = result.scalar_one_or_none() # 准备要保存的配置数据(加密敏感字段) llm_data = config_in.llmConfig.dict(exclude_none=True) if config_in.llmConfig else {} other_data = config_in.otherConfig.dict(exclude_none=True) if config_in.otherConfig else {} # 加密敏感字段 llm_data_encrypted = encrypt_config(llm_data, SENSITIVE_LLM_FIELDS) other_data_encrypted = encrypt_config(other_data, SENSITIVE_OTHER_FIELDS) if not config: # 创建新配置 config = UserConfig( user_id=current_user.id, llm_config=json.dumps(llm_data_encrypted), other_config=json.dumps(other_data_encrypted), ) db.add(config) else: # 更新现有配置 if config_in.llmConfig: existing_llm = json.loads(config.llm_config) if config.llm_config else {} # 先解密现有数据,再合并新数据,最后加密 existing_llm = decrypt_config(existing_llm, SENSITIVE_LLM_FIELDS) existing_llm.update(llm_data) # 使用未加密的新数据合并 config.llm_config = json.dumps(encrypt_config(existing_llm, SENSITIVE_LLM_FIELDS)) if config_in.otherConfig: existing_other = json.loads(config.other_config) if config.other_config else {} # 先解密现有数据,再合并新数据,最后加密 existing_other = decrypt_config(existing_other, SENSITIVE_OTHER_FIELDS) existing_other.update(other_data) # 使用未加密的新数据合并 config.other_config = json.dumps(encrypt_config(existing_other, SENSITIVE_OTHER_FIELDS)) await db.commit() await db.refresh(config) # 获取系统默认配置并合并(与 get_my_config 保持一致) default_config = get_default_config() user_llm_config = json.loads(config.llm_config) if config.llm_config else {} user_other_config = json.loads(config.other_config) if config.other_config else {} # 解密后返回给前端 user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS) user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS) merged_llm_config = {**default_config["llmConfig"], **user_llm_config} merged_other_config = {**default_config["otherConfig"], **user_other_config} return UserConfigResponse( id=config.id, user_id=config.user_id, llmConfig=merged_llm_config, otherConfig=merged_other_config, created_at=config.created_at.isoformat() if config.created_at else "", updated_at=config.updated_at.isoformat() if config.updated_at else None, ) @router.delete("/me") async def delete_my_config( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: """删除当前用户的配置(恢复为默认)""" result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) config = result.scalar_one_or_none() if config: await db.delete(config) await db.commit() return {"message": "配置已删除"} class LLMTestRequest(BaseModel): """LLM测试请求""" provider: str apiKey: str model: Optional[str] = None baseUrl: Optional[str] = None class LLMTestResponse(BaseModel): """LLM测试响应""" success: bool message: str model: Optional[str] = None response: Optional[str] = None # 调试信息 debug: Optional[dict] = None @router.post("/test-llm", response_model=LLMTestResponse) async def test_llm_connection( request: LLMTestRequest, current_user: User = Depends(deps.get_current_user), ) -> Any: """测试LLM连接是否正常""" from app.services.llm.factory import LLMFactory, NATIVE_ONLY_PROVIDERS from app.services.llm.adapters import LiteLLMAdapter, BaiduAdapter, MinimaxAdapter, DoubaoAdapter from app.services.llm.types import LLMConfig, LLMProvider, LLMRequest, LLMMessage, DEFAULT_MODELS, DEFAULT_BASE_URLS import traceback import time start_time = time.time() debug_info = { "provider": request.provider, "model_requested": request.model, "base_url_requested": request.baseUrl, "api_key_length": len(request.apiKey) if request.apiKey else 0, "api_key_prefix": request.apiKey[:8] + "..." if request.apiKey and len(request.apiKey) > 8 else "(empty)", } try: # 解析provider provider_map = { 'gemini': LLMProvider.GEMINI, 'openai': LLMProvider.OPENAI, 'claude': LLMProvider.CLAUDE, 'qwen': LLMProvider.QWEN, 'deepseek': LLMProvider.DEEPSEEK, 'zhipu': LLMProvider.ZHIPU, 'moonshot': LLMProvider.MOONSHOT, 'baidu': LLMProvider.BAIDU, 'minimax': LLMProvider.MINIMAX, 'doubao': LLMProvider.DOUBAO, 'ollama': LLMProvider.OLLAMA, } provider = provider_map.get(request.provider.lower()) if not provider: debug_info["error_type"] = "unsupported_provider" return LLMTestResponse( success=False, message=f"不支持的LLM提供商: {request.provider}", debug=debug_info ) # 获取默认模型 model = request.model or DEFAULT_MODELS.get(provider) base_url = request.baseUrl or DEFAULT_BASE_URLS.get(provider, "") debug_info["model_used"] = model debug_info["base_url_used"] = base_url debug_info["is_native_adapter"] = provider in NATIVE_ONLY_PROVIDERS print(f"[LLM Test] 开始测试: provider={provider.value}, model={model}, base_url={base_url}") # 创建配置 config = LLMConfig( provider=provider, api_key=request.apiKey, model=model, base_url=request.baseUrl, timeout=30, # 测试使用较短的超时时间 max_tokens=50, # 测试使用较少的token ) # 直接创建新的适配器实例(不使用缓存),确保使用最新的配置 if provider in NATIVE_ONLY_PROVIDERS: native_adapter_map = { LLMProvider.BAIDU: BaiduAdapter, LLMProvider.MINIMAX: MinimaxAdapter, LLMProvider.DOUBAO: DoubaoAdapter, } adapter = native_adapter_map[provider](config) debug_info["adapter_type"] = type(adapter).__name__ else: adapter = LiteLLMAdapter(config) debug_info["adapter_type"] = "LiteLLMAdapter" # 获取 LiteLLM 实际使用的模型名 debug_info["litellm_model"] = getattr(adapter, '_get_model_name', lambda: model)() if hasattr(adapter, '_get_model_name') else model test_request = LLMRequest( messages=[ LLMMessage(role="user", content="Say 'Hello' in one word.") ], max_tokens=50, ) print(f"[LLM Test] 发送测试请求...") response = await adapter.complete(test_request) elapsed_time = time.time() - start_time debug_info["elapsed_time_ms"] = round(elapsed_time * 1000, 2) # 验证响应内容 if not response or not response.content: debug_info["error_type"] = "empty_response" debug_info["raw_response"] = str(response) if response else None print(f"[LLM Test] 空响应: {response}") return LLMTestResponse( success=False, message="LLM 返回空响应,请检查 API Key 和配置", debug=debug_info ) debug_info["response_length"] = len(response.content) debug_info["usage"] = { "prompt_tokens": getattr(response, 'prompt_tokens', None), "completion_tokens": getattr(response, 'completion_tokens', None), "total_tokens": getattr(response, 'total_tokens', None), } print(f"[LLM Test] 成功! 响应: {response.content[:50]}... 耗时: {elapsed_time:.2f}s") return LLMTestResponse( success=True, message=f"连接成功 ({elapsed_time:.2f}s)", model=model, response=response.content[:100] if response.content else None, debug=debug_info ) except Exception as e: elapsed_time = time.time() - start_time error_msg = str(e) error_type = type(e).__name__ debug_info["elapsed_time_ms"] = round(elapsed_time * 1000, 2) debug_info["error_type"] = error_type debug_info["error_message"] = error_msg debug_info["traceback"] = traceback.format_exc() # 提取 LLMError 中的 api_response if hasattr(e, 'api_response') and e.api_response: debug_info["api_response"] = e.api_response if hasattr(e, 'status_code') and e.status_code: debug_info["status_code"] = e.status_code print(f"[LLM Test] 失败: {error_type}: {error_msg}") print(f"[LLM Test] Traceback:\n{traceback.format_exc()}") # 提供更友好的错误信息 friendly_message = error_msg # 优先检查余额不足(因为某些 API 用 429 表示余额不足) if any(keyword in error_msg for keyword in ["余额不足", "资源包", "充值", "quota", "insufficient", "balance", "402"]): friendly_message = "账户余额不足或配额已用尽,请充值后重试" debug_info["error_category"] = "insufficient_balance" elif "401" in error_msg or "invalid_api_key" in error_msg.lower() or "incorrect api key" in error_msg.lower(): friendly_message = "API Key 无效或已过期,请检查后重试" debug_info["error_category"] = "auth_invalid_key" elif "authentication" in error_msg.lower(): friendly_message = "认证失败,请检查 API Key 是否正确" debug_info["error_category"] = "auth_failed" elif "timeout" in error_msg.lower(): friendly_message = "连接超时,请检查网络或 API 地址是否正确" debug_info["error_category"] = "timeout" elif "connection" in error_msg.lower() or "connect" in error_msg.lower(): friendly_message = "无法连接到 API 服务,请检查网络或 API 地址" debug_info["error_category"] = "connection" elif "rate" in error_msg.lower() and "limit" in error_msg.lower(): friendly_message = "API 请求频率超限,请稍后重试" debug_info["error_category"] = "rate_limit" elif "model" in error_msg.lower() and ("not found" in error_msg.lower() or "does not exist" in error_msg.lower()): friendly_message = f"模型 '{debug_info.get('model_used', 'unknown')}' 不存在或无权访问" debug_info["error_category"] = "model_not_found" else: debug_info["error_category"] = "unknown" return LLMTestResponse( success=False, message=friendly_message, debug=debug_info ) @router.get("/llm-providers") async def get_llm_providers() -> Any: """获取支持的LLM提供商列表""" from app.services.llm.factory import LLMFactory from app.services.llm.types import LLMProvider, DEFAULT_BASE_URLS providers = [] for provider in LLMFactory.get_supported_providers(): providers.append({ "id": provider.value, "name": provider.value.upper(), "defaultModel": LLMFactory.get_default_model(provider), "models": LLMFactory.get_available_models(provider), "defaultBaseUrl": DEFAULT_BASE_URLS.get(provider, ""), }) return {"providers": providers}