refactor(llm): consolidate LLM adapters with LiteLLM unified layer

- Replace individual adapter implementations (OpenAI, Claude, Gemini, DeepSeek, Qwen, Zhipu, Moonshot, Ollama) with unified LiteLLM adapter
- Keep native adapters for providers with special API formats (Baidu, MiniMax, Doubao)
- Update LLM factory to route requests through LiteLLM for supported providers
- Add test-llm endpoint to validate LLM connections using a short timeout and a small token cap
- Add get-llm-providers endpoint to retrieve supported providers and their configurations
- Update config.py to ignore extra environment variables (VITE_* frontend variables)
- Refactor Baidu adapter to use new complete() method signature and improve error handling
- Update pyproject.toml dependencies to include litellm package
- Update env.example with new configuration options
- Simplify adapter initialization and reduce code duplication across multiple provider implementations
lintsinghua 2025-11-28 16:41:39 +08:00
parent 1fc0ecd14a
commit 22c528acf1
20 changed files with 807 additions and 1548 deletions
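For orientation, a minimal usage sketch of the consolidated layer (the factory, config, and request types are taken from the diffs below; the call site itself is illustrative, with placeholder values):

    import asyncio

    from app.services.llm.factory import LLMFactory
    from app.services.llm.types import LLMConfig, LLMMessage, LLMProvider, LLMRequest

    async def main() -> None:
        # OpenAI-family providers now route through the LiteLLM unified adapter;
        # Baidu/MiniMax/Doubao would return their native adapters instead.
        config = LLMConfig(provider=LLMProvider.OPENAI, api_key="sk-your-api-key", model="gpt-4o-mini")
        adapter = LLMFactory.create_adapter(config)
        response = await adapter.complete(
            LLMRequest(messages=[LLMMessage(role="user", content="Say hello.")])
        )
        print(response.content)

    asyncio.run(main())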


@@ -216,3 +216,109 @@ async def delete_my_config(
    return {"message": "配置已删除"}


class LLMTestRequest(BaseModel):
    """LLM test request"""
    provider: str
    apiKey: str
    model: Optional[str] = None
    baseUrl: Optional[str] = None


class LLMTestResponse(BaseModel):
    """LLM test response"""
    success: bool
    message: str
    model: Optional[str] = None
    response: Optional[str] = None


@router.post("/test-llm", response_model=LLMTestResponse)
async def test_llm_connection(
    request: LLMTestRequest,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """Check that the LLM connection works"""
    from app.services.llm.factory import LLMFactory
    from app.services.llm.types import LLMConfig, LLMProvider, LLMRequest, LLMMessage, DEFAULT_MODELS

    try:
        # Resolve the provider
        provider_map = {
            'gemini': LLMProvider.GEMINI,
            'openai': LLMProvider.OPENAI,
            'claude': LLMProvider.CLAUDE,
            'qwen': LLMProvider.QWEN,
            'deepseek': LLMProvider.DEEPSEEK,
            'zhipu': LLMProvider.ZHIPU,
            'moonshot': LLMProvider.MOONSHOT,
            'baidu': LLMProvider.BAIDU,
            'minimax': LLMProvider.MINIMAX,
            'doubao': LLMProvider.DOUBAO,
            'ollama': LLMProvider.OLLAMA,
        }
        provider = provider_map.get(request.provider.lower())
        if not provider:
            return LLMTestResponse(
                success=False,
                message=f"不支持的LLM提供商: {request.provider}"
            )

        # Fall back to the provider's default model
        model = request.model or DEFAULT_MODELS.get(provider)

        # Build the config
        config = LLMConfig(
            provider=provider,
            api_key=request.apiKey,
            model=model,
            base_url=request.baseUrl,
            timeout=30,  # short timeout for the connection test
            max_tokens=50,  # small token budget for the connection test
        )

        # Create the adapter and run a test completion
        adapter = LLMFactory.create_adapter(config)
        test_request = LLMRequest(
            messages=[
                LLMMessage(role="user", content="Say 'Hello' in one word.")
            ],
            max_tokens=50,
        )
        response = await adapter.complete(test_request)

        return LLMTestResponse(
            success=True,
            message="LLM连接测试成功",
            model=model,
            response=response.content[:100] if response.content else None
        )
    except Exception as e:
        return LLMTestResponse(
            success=False,
            message=f"LLM连接测试失败: {str(e)}"
        )


@router.get("/llm-providers")
async def get_llm_providers() -> Any:
    """Return the list of supported LLM providers"""
    from app.services.llm.factory import LLMFactory
    from app.services.llm.types import LLMProvider, DEFAULT_BASE_URLS

    providers = []
    for provider in LLMFactory.get_supported_providers():
        providers.append({
            "id": provider.value,
            "name": provider.value.upper(),
            "defaultModel": LLMFactory.get_default_model(provider),
            "models": LLMFactory.get_available_models(provider),
            "defaultBaseUrl": DEFAULT_BASE_URLS.get(provider, ""),
        })

    return {"providers": providers}


@@ -77,6 +77,7 @@ class Settings(BaseSettings):
     class Config:
         case_sensitive = True
         env_file = ".env"
+        extra = "ignore"  # ignore extra environment variables (e.g. VITE_* frontend variables)


 settings = Settings()
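A minimal, standalone illustration of what this setting buys (assuming pydantic-settings, where unknown keys in an env file are rejected by default; DemoSettings is hypothetical):

    from pydantic_settings import BaseSettings

    class DemoSettings(BaseSettings):
        APP_NAME: str = "demo"

        class Config:
            env_file = ".env"
            extra = "ignore"  # stray frontend entries like VITE_API_URL are skipped instead of raising

    settings = DemoSettings()  # loads .env, silently ignoring unknown keys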


@@ -1,30 +1,22 @@
 """
 LLM adapter module
+
+Adapters fall into two groups:
+1. The LiteLLM unified adapter - supports OpenAI, Claude, Gemini, DeepSeek, Qwen, Zhipu, Moonshot, Ollama
+2. Native adapters - for providers with non-standard API formats: Baidu, MiniMax, Doubao
 """
-from .openai_adapter import OpenAIAdapter
-from .gemini_adapter import GeminiAdapter
-from .claude_adapter import ClaudeAdapter
-from .deepseek_adapter import DeepSeekAdapter
-from .qwen_adapter import QwenAdapter
-from .zhipu_adapter import ZhipuAdapter
-from .moonshot_adapter import MoonshotAdapter
+# LiteLLM unified adapter
+from .litellm_adapter import LiteLLMAdapter
+
+# Native adapters (providers with non-standard API formats)
 from .baidu_adapter import BaiduAdapter
 from .minimax_adapter import MinimaxAdapter
 from .doubao_adapter import DoubaoAdapter
-from .ollama_adapter import OllamaAdapter

 __all__ = [
-    'OpenAIAdapter',
-    'GeminiAdapter',
-    'ClaudeAdapter',
-    'DeepSeekAdapter',
-    'QwenAdapter',
-    'ZhipuAdapter',
-    'MoonshotAdapter',
-    'BaiduAdapter',
-    'MinimaxAdapter',
-    'DoubaoAdapter',
-    'OllamaAdapter',
+    "LiteLLMAdapter",
+    "BaiduAdapter",
+    "MinimaxAdapter",
+    "DoubaoAdapter",
 ]


@@ -3,10 +3,9 @@
 """
 import httpx
-import json
 from typing import Optional

 from ..base_adapter import BaseLLMAdapter
-from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
+from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage


 class BaiduAdapter(BaseLLMAdapter):
@@ -70,8 +69,16 @@ class BaiduAdapter(BaseLLMAdapter):
         return self._access_token

-    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
+    async def complete(self, request: LLMRequest) -> LLMResponse:
         """Execute the actual API call"""
+        try:
+            await self.validate_config()
+            return await self.retry(lambda: self._send_request(request))
+        except Exception as error:
+            self.handle_error(error, "百度文心一言 API调用失败")
+
+    async def _send_request(self, request: LLMRequest) -> LLMResponse:
+        """Send the request"""
         access_token = await self._get_access_token()

         # Map the model to its API endpoint
@ -84,55 +91,56 @@ class BaiduAdapter(BaseLLMAdapter):
payload = { payload = {
"messages": messages, "messages": messages,
"temperature": request.temperature or self.config.temperature, "temperature": request.temperature if request.temperature is not None else self.config.temperature,
"top_p": request.top_p or self.config.top_p, "top_p": request.top_p if request.top_p is not None else self.config.top_p,
} }
if request.max_tokens or self.config.max_tokens: if request.max_tokens or self.config.max_tokens:
payload["max_output_tokens"] = request.max_tokens or self.config.max_tokens payload["max_output_tokens"] = request.max_tokens or self.config.max_tokens
async with httpx.AsyncClient(timeout=self.config.timeout) as client: response = await self.client.post(
response = await client.post( url,
url, headers=self.build_headers(),
json=payload, json=payload
headers={"Content-Type": "application/json"} )
)
if response.status_code != 200:
if response.status_code != 200: error_data = response.json() if response.text else {}
raise LLMError( error_msg = error_data.get("error_msg", f"HTTP {response.status_code}")
f"百度API错误: {response.text}", raise Exception(f"{error_msg}")
provider="baidu",
status_code=response.status_code data = response.json()
)
if "error_code" in data:
data = response.json() raise Exception(f"百度API错误: {data.get('error_msg', '未知错误')}")
if "error_code" in data: usage = None
raise LLMError( if "usage" in data:
f"百度API错误: {data.get('error_msg', '未知错误')}", usage = LLMUsage(
provider="baidu", prompt_tokens=data["usage"].get("prompt_tokens", 0),
status_code=data.get("error_code") completion_tokens=data["usage"].get("completion_tokens", 0),
) total_tokens=data["usage"].get("total_tokens", 0)
return LLMResponse(
content=data.get("result", ""),
model=model,
usage=data.get("usage"),
finish_reason=data.get("finish_reason")
) )
return LLMResponse(
content=data.get("result", ""),
model=model,
usage=usage,
finish_reason=data.get("finish_reason")
)
async def validate_config(self) -> bool: async def validate_config(self) -> bool:
"""验证配置是否有效""" """验证配置是否有效"""
try: if not self.config.api_key:
await self._get_access_token() raise LLMError(
return True "API Key未配置",
except Exception: provider=LLMProvider.BAIDU
return False )
if ":" not in self.config.api_key:
def get_provider(self) -> str: raise LLMError(
return "baidu" "百度API需要同时提供API Key和Secret Key格式api_key:secret_key",
provider=LLMProvider.BAIDU
def get_model(self) -> str: )
return self.config.model or "ERNIE-3.5-8K" return True


@@ -1,94 +0,0 @@
"""
Anthropic Claude adapter
"""
from typing import Dict, Any

from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class ClaudeAdapter(BaseLLMAdapter):
    """Claude adapter"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.CLAUDE, "https://api.anthropic.com/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "Claude API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # The Claude API expects system messages to be split out
        system_message = None
        messages = []
        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                messages.append({
                    "role": msg.role,
                    "content": msg.content
                })

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens or 4096,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }
        if system_message:
            request_body["system"] = system_message

        # Build request headers
        headers = {
            "x-api-key": self.config.api_key,
            "anthropic-version": "2023-06-01",
        }

        url = f"{self.base_url.rstrip('/')}/messages"
        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        if not data.get("content") or not data["content"][0]:
            raise Exception("API响应格式异常: 缺少content字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("input_tokens", 0),
                completion_tokens=data["usage"].get("output_tokens", 0),
                total_tokens=data["usage"].get("input_tokens", 0) + data["usage"].get("output_tokens", 0)
            )

        return LLMResponse(
            content=data["content"][0].get("text", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=data.get("stop_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model.startswith("claude-"):
            raise Exception(f"无效的Claude模型: {self.config.model}")
        return True


@@ -1,82 +0,0 @@
"""
DeepSeek adapter - OpenAI-compatible format
"""
from typing import Dict, Any

from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class DeepSeekAdapter(BaseLLMAdapter):
    """DeepSeek adapter"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.DEEPSEEK, "https://api.deepseek.com")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "DeepSeek API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # The DeepSeek API is OpenAI-compatible
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
            "frequency_penalty": self.config.frequency_penalty,
            "presence_penalty": self.config.presence_penalty,
        }

        url = f"{self.base_url.rstrip('/')}/v1/chat/completions"
        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        choice = data.get("choices", [{}])[0]
        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定DeepSeek模型")
        return True


@@ -2,9 +2,8 @@
 ByteDance Doubao adapter
 """
-import httpx
 from ..base_adapter import BaseLLMAdapter
-from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
+from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage


 class DoubaoAdapter(BaseLLMAdapter):

@@ -17,8 +16,16 @@ class DoubaoAdapter(BaseLLMAdapter):
         super().__init__(config)
         self._base_url = config.base_url or "https://ark.cn-beijing.volces.com/api/v3"

-    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
+    async def complete(self, request: LLMRequest) -> LLMResponse:
         """Execute the actual API call"""
+        try:
+            await self.validate_config()
+            return await self.retry(lambda: self._send_request(request))
+        except Exception as error:
+            self.handle_error(error, "豆包 API调用失败")
+
+    async def _send_request(self, request: LLMRequest) -> LLMResponse:
+        """Send the request"""
         url = f"{self._base_url}/chat/completions"

         messages = [{"role": m.role, "content": m.content} for m in request.messages]

@@ -26,63 +33,52 @@ class DoubaoAdapter(BaseLLMAdapter):
         payload = {
             "model": self.config.model or "doubao-pro-32k",
             "messages": messages,
-            "temperature": request.temperature or self.config.temperature,
-            "top_p": request.top_p or self.config.top_p,
+            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
+            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
+            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
         }
-        if request.max_tokens or self.config.max_tokens:
-            payload["max_tokens"] = request.max_tokens or self.config.max_tokens

         headers = {
-            "Content-Type": "application/json",
             "Authorization": f"Bearer {self.config.api_key}",
         }

-        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
-            response = await client.post(url, json=payload, headers=headers)
-
-            if response.status_code != 200:
-                raise LLMError(
-                    f"豆包API错误: {response.text}",
-                    provider="doubao",
-                    status_code=response.status_code
-                )
-
-            data = response.json()
-            if "error" in data:
-                raise LLMError(
-                    f"豆包API错误: {data['error'].get('message', '未知错误')}",
-                    provider="doubao"
-                )
-
-            choices = data.get("choices", [])
-            if not choices:
-                raise LLMError("豆包API返回空响应", provider="doubao")
-
-            return LLMResponse(
-                content=choices[0].get("message", {}).get("content", ""),
-                model=data.get("model", self.config.model or "doubao-pro-32k"),
-                usage=data.get("usage"),
-                finish_reason=choices[0].get("finish_reason")
-            )
+        response = await self.client.post(
+            url,
+            headers=self.build_headers(headers),
+            json=payload
+        )
+
+        if response.status_code != 200:
+            error_data = response.json() if response.text else {}
+            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
+            raise Exception(f"{error_msg}")
+
+        data = response.json()
+
+        choice = data.get("choices", [{}])[0]
+        if not choice:
+            raise Exception("API响应格式异常: 缺少choices字段")
+
+        usage = None
+        if "usage" in data:
+            usage = LLMUsage(
+                prompt_tokens=data["usage"].get("prompt_tokens", 0),
+                completion_tokens=data["usage"].get("completion_tokens", 0),
+                total_tokens=data["usage"].get("total_tokens", 0)
+            )
+
+        return LLMResponse(
+            content=choice.get("message", {}).get("content", ""),
+            model=data.get("model"),
+            usage=usage,
+            finish_reason=choice.get("finish_reason")
+        )

     async def validate_config(self) -> bool:
         """Validate the configuration"""
-        try:
-            test_request = LLMRequest(
-                messages=[{"role": "user", "content": "Hi"}],
-                max_tokens=10
-            )
-            await self._do_complete(test_request)
-            return True
-        except Exception:
-            return False
-
-    def get_provider(self) -> str:
-        return "doubao"
-
-    def get_model(self) -> str:
-        return self.config.model or "doubao-pro-32k"
+        await super().validate_config()
+        if not self.config.model:
+            raise LLMError("未指定豆包模型", provider=LLMProvider.DOUBAO)
+        return True


@@ -1,114 +0,0 @@
"""
Google Gemini adapter - supports the official API and proxy gateways
"""
from typing import Dict, Any, List

from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class GeminiAdapter(BaseLLMAdapter):
    """Gemini adapter"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.GEMINI, "https://generativelanguage.googleapis.com/v1beta")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._generate_content(request))
        except Exception as error:
            self.handle_error(error, "Gemini API调用失败")

    async def _generate_content(self, request: LLMRequest) -> LLMResponse:
        # Convert messages to the Gemini format
        contents: List[Dict[str, Any]] = []
        system_content = ""

        for msg in request.messages:
            if msg.role == "system":
                system_content = msg.content
            else:
                role = "model" if msg.role == "assistant" else "user"
                contents.append({
                    "role": role,
                    "parts": [{"text": msg.content}]
                })

        # Merge the system message into the first user message
        if system_content and contents:
            contents[0]["parts"][0]["text"] = f"{system_content}\n\n{contents[0]['parts'][0]['text']}"

        # Build the request body
        request_body = {
            "contents": contents,
            "generationConfig": {
                "temperature": request.temperature if request.temperature is not None else self.config.temperature,
                "maxOutputTokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
                "topP": request.top_p if request.top_p is not None else self.config.top_p,
            }
        }

        # The API key goes in a URL parameter
        url = f"{self.base_url}/models/{self.config.model}:generateContent?key={self.config.api_key}"
        response = await self.client.post(
            url,
            headers=self.build_headers(),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        # Parse the Gemini response format
        candidates = data.get("candidates", [])
        if not candidates:
            # Check for an error payload
            if "error" in data:
                error_msg = data["error"].get("message", "未知错误")
                raise Exception(f"Gemini API错误: {error_msg}")
            raise Exception("API响应格式异常: 缺少candidates字段")

        candidate = candidates[0]
        if not candidate or "content" not in candidate:
            raise Exception("API响应格式异常: 缺少content字段")

        text_parts = candidate.get("content", {}).get("parts", [])
        if not text_parts:
            raise Exception("API响应格式异常: content.parts为空")

        text = "".join(part.get("text", "") for part in text_parts)

        # Reject empty responses
        if not text or not text.strip():
            finish_reason = candidate.get("finishReason", "unknown")
            raise Exception(f"Gemini返回空响应 - Finish Reason: {finish_reason}")

        usage = None
        if "usageMetadata" in data:
            usage_data = data["usageMetadata"]
            usage = LLMUsage(
                prompt_tokens=usage_data.get("promptTokenCount", 0),
                completion_tokens=usage_data.get("candidatesTokenCount", 0),
                total_tokens=usage_data.get("totalTokenCount", 0)
            )

        return LLMResponse(
            content=text,
            model=self.config.model,
            usage=usage,
            finish_reason=candidate.get("finishReason", "stop")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model.startswith("gemini-"):
            raise Exception(f"无效的Gemini模型: {self.config.model}")
        return True


@@ -0,0 +1,182 @@
"""
LiteLLM unified adapter

Calls multiple LLM providers through LiteLLM using a unified OpenAI-compatible format
"""
from typing import Dict, Any, Optional

from ..base_adapter import BaseLLMAdapter
from ..types import (
    LLMConfig,
    LLMRequest,
    LLMResponse,
    LLMUsage,
    LLMProvider,
    LLMError,
    DEFAULT_BASE_URLS,
)


class LiteLLMAdapter(BaseLLMAdapter):
    """
    LiteLLM unified adapter

    Supported providers:
    - OpenAI (openai/gpt-4o-mini)
    - Claude (anthropic/claude-3-5-sonnet-20241022)
    - Gemini (gemini/gemini-1.5-flash)
    - DeepSeek (deepseek/deepseek-chat)
    - Qwen (qwen/qwen-turbo) - via OpenAI-compatible mode
    - Zhipu (zhipu/glm-4-flash) - via OpenAI-compatible mode
    - Moonshot (moonshot/moonshot-v1-8k) - via OpenAI-compatible mode
    - Ollama (ollama/llama3)
    """

    # LiteLLM model prefix map
    PROVIDER_PREFIX_MAP = {
        LLMProvider.OPENAI: "openai",
        LLMProvider.CLAUDE: "anthropic",
        LLMProvider.GEMINI: "gemini",
        LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "openai",  # OpenAI-compatible mode
        LLMProvider.ZHIPU: "openai",  # OpenAI-compatible mode
        LLMProvider.MOONSHOT: "openai",  # OpenAI-compatible mode
        LLMProvider.OLLAMA: "ollama",
    }

    # Providers that need a custom base_url
    CUSTOM_BASE_URL_PROVIDERS = {
        LLMProvider.QWEN,
        LLMProvider.ZHIPU,
        LLMProvider.MOONSHOT,
        LLMProvider.DEEPSEEK,
    }

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._litellm_model = self._get_litellm_model()
        self._api_base = self._get_api_base()

    def _get_litellm_model(self) -> str:
        """Build the LiteLLM-format model name"""
        provider = self.config.provider
        model = self.config.model

        # Providers on OpenAI-compatible mode use the bare model name
        if provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return model

        # Natively supported providers get a prefix
        prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")

        # Skip if the model name already carries a prefix
        if "/" in model:
            return model

        return f"{prefix}/{model}"

    def _get_api_base(self) -> Optional[str]:
        """Resolve the API base URL"""
        # A user-configured base_url takes precedence
        if self.config.base_url:
            return self.config.base_url

        # Providers that need a custom base_url fall back to the default
        if self.config.provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return DEFAULT_BASE_URLS.get(self.config.provider)

        # Ollama uses the local address
        if self.config.provider == LLMProvider.OLLAMA:
            return DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434")

        return None

    async def complete(self, request: LLMRequest) -> LLMResponse:
        """Send the request through LiteLLM"""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, f"LiteLLM ({self.config.provider.value}) API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """Dispatch the request to LiteLLM"""
        import litellm

        # Build the messages
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # Build the request parameters
        kwargs: Dict[str, Any] = {
            "model": self._litellm_model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        # Set the API key
        if self.config.api_key and self.config.api_key != "ollama":
            kwargs["api_key"] = self.config.api_key

        # Set the API base URL
        if self._api_base:
            kwargs["api_base"] = self._api_base

        # Set the timeout
        kwargs["timeout"] = self.config.timeout

        # Extra parameters for the OpenAI provider
        if self.config.provider == LLMProvider.OPENAI:
            kwargs["frequency_penalty"] = self.config.frequency_penalty
            kwargs["presence_penalty"] = self.config.presence_penalty

        # Call LiteLLM
        response = await litellm.acompletion(**kwargs)

        # Parse the response
        choice = response.choices[0] if response.choices else None
        if not choice:
            raise LLMError("API响应格式异常: 缺少choices字段", self.config.provider)

        usage = None
        if hasattr(response, "usage") and response.usage:
            usage = LLMUsage(
                prompt_tokens=response.usage.prompt_tokens or 0,
                completion_tokens=response.usage.completion_tokens or 0,
                total_tokens=response.usage.total_tokens or 0,
            )

        return LLMResponse(
            content=choice.message.content or "",
            model=response.model,
            usage=usage,
            finish_reason=choice.finish_reason,
        )

    async def validate_config(self) -> bool:
        """Validate the configuration"""
        # Ollama does not need an API key
        if self.config.provider == LLMProvider.OLLAMA:
            if not self.config.model:
                raise LLMError("未指定 Ollama 模型", LLMProvider.OLLAMA)
            return True

        # All other providers need an API key
        if not self.config.api_key:
            raise LLMError(
                f"API Key未配置 ({self.config.provider.value})",
                self.config.provider,
            )

        if not self.config.model:
            raise LLMError(
                f"未指定模型 ({self.config.provider.value})",
                self.config.provider,
            )

        return True

    @classmethod
    def supports_provider(cls, provider: LLMProvider) -> bool:
        """Check whether a provider is supported"""
        return provider in cls.PROVIDER_PREFIX_MAP
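To make the mapping concrete, here is what the constructor derives for two providers (illustrative model names; the private attributes are read only for demonstration):

    from app.services.llm.adapters import LiteLLMAdapter
    from app.services.llm.types import LLMConfig, LLMProvider

    # Natively supported by LiteLLM: the model gets a provider prefix, no custom base URL.
    gemini = LiteLLMAdapter(LLMConfig(provider=LLMProvider.GEMINI, api_key="...", model="gemini-1.5-flash"))
    print(gemini._litellm_model, gemini._api_base)  # "gemini/gemini-1.5-flash" None

    # OpenAI-compatible mode: the model name is kept as-is, and the provider's
    # default base URL from DEFAULT_BASE_URLS is passed to LiteLLM instead.
    qwen = LiteLLMAdapter(LLMConfig(provider=LLMProvider.QWEN, api_key="...", model="qwen-turbo"))
    print(qwen._litellm_model, qwen._api_base)  # "qwen-turbo" <DashScope-compatible URL>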


@@ -2,9 +2,8 @@
 MiniMax adapter
 """
-import httpx
 from ..base_adapter import BaseLLMAdapter
-from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
+from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage


 class MinimaxAdapter(BaseLLMAdapter):

@@ -14,8 +13,16 @@ class MinimaxAdapter(BaseLLMAdapter):
         super().__init__(config)
         self._base_url = config.base_url or "https://api.minimax.chat/v1"

-    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
+    async def complete(self, request: LLMRequest) -> LLMResponse:
         """Execute the actual API call"""
+        try:
+            await self.validate_config()
+            return await self.retry(lambda: self._send_request(request))
+        except Exception as error:
+            self.handle_error(error, "MiniMax API调用失败")
+
+    async def _send_request(self, request: LLMRequest) -> LLMResponse:
+        """Send the request"""
         url = f"{self._base_url}/text/chatcompletion_v2"

         messages = [{"role": m.role, "content": m.content} for m in request.messages]

@@ -23,63 +30,58 @@ class MinimaxAdapter(BaseLLMAdapter):
         payload = {
             "model": self.config.model or "abab6.5-chat",
             "messages": messages,
-            "temperature": request.temperature or self.config.temperature,
-            "top_p": request.top_p or self.config.top_p,
+            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
+            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
+            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
         }
-        if request.max_tokens or self.config.max_tokens:
-            payload["max_tokens"] = request.max_tokens or self.config.max_tokens

         headers = {
-            "Content-Type": "application/json",
             "Authorization": f"Bearer {self.config.api_key}",
         }

-        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
-            response = await client.post(url, json=payload, headers=headers)
-
-            if response.status_code != 200:
-                raise LLMError(
-                    f"MiniMax API错误: {response.text}",
-                    provider="minimax",
-                    status_code=response.status_code
-                )
-
-            data = response.json()
-            if data.get("base_resp", {}).get("status_code") != 0:
-                raise LLMError(
-                    f"MiniMax API错误: {data.get('base_resp', {}).get('status_msg', '未知错误')}",
-                    provider="minimax"
-                )
-
-            choices = data.get("choices", [])
-            if not choices:
-                raise LLMError("MiniMax API返回空响应", provider="minimax")
-
-            return LLMResponse(
-                content=choices[0].get("message", {}).get("content", ""),
-                model=self.config.model or "abab6.5-chat",
-                usage=data.get("usage"),
-                finish_reason=choices[0].get("finish_reason")
-            )
+        response = await self.client.post(
+            url,
+            headers=self.build_headers(headers),
+            json=payload
+        )
+
+        if response.status_code != 200:
+            error_data = response.json() if response.text else {}
+            error_msg = error_data.get("base_resp", {}).get("status_msg", f"HTTP {response.status_code}")
+            raise Exception(f"{error_msg}")
+
+        data = response.json()
+
+        # MiniMax-specific error envelope
+        if data.get("base_resp", {}).get("status_code") != 0:
+            error_msg = data.get("base_resp", {}).get("status_msg", "未知错误")
+            raise Exception(f"MiniMax API错误: {error_msg}")
+
+        choice = data.get("choices", [{}])[0]
+        if not choice:
+            raise Exception("API响应格式异常: 缺少choices字段")
+
+        usage = None
+        if "usage" in data:
+            usage = LLMUsage(
+                prompt_tokens=data["usage"].get("prompt_tokens", 0),
+                completion_tokens=data["usage"].get("completion_tokens", 0),
+                total_tokens=data["usage"].get("total_tokens", 0)
+            )
+
+        return LLMResponse(
+            content=choice.get("message", {}).get("content", ""),
+            model=data.get("model"),
+            usage=usage,
+            finish_reason=choice.get("finish_reason")
+        )

     async def validate_config(self) -> bool:
         """Validate the configuration"""
-        try:
-            test_request = LLMRequest(
-                messages=[{"role": "user", "content": "Hi"}],
-                max_tokens=10
-            )
-            await self._do_complete(test_request)
-            return True
-        except Exception:
-            return False
-
-    def get_provider(self) -> str:
-        return "minimax"
-
-    def get_model(self) -> str:
-        return self.config.model or "abab6.5-chat"
+        await super().validate_config()
+        if not self.config.model:
+            raise LLMError("未指定MiniMax模型", provider=LLMProvider.MINIMAX)
+        return True


@@ -1,80 +0,0 @@
"""
Moonshot AI Kimi adapter - OpenAI-compatible format
"""
from typing import Dict, Any

from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class MoonshotAdapter(BaseLLMAdapter):
    """Moonshot AI Kimi adapter"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.MOONSHOT, "https://api.moonshot.cn/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "Moonshot API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # The Moonshot API is OpenAI-compatible
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        url = f"{self.base_url.rstrip('/')}/chat/completions"
        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        choice = data.get("choices", [{}])[0]
        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定Moonshot模型")
        return True


@@ -1,83 +0,0 @@
"""
Ollama local LLM adapter - OpenAI-compatible format
"""
from typing import Dict, Any

from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class OllamaAdapter(BaseLLMAdapter):
    """Ollama local model adapter"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            # Ollama runs locally; skip API key validation
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "Ollama API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # Ollama is OpenAI-compatible
        headers = {}
        if self.config.api_key:
            headers["Authorization"] = f"Bearer {self.config.api_key}"

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        # Ollama uses a different name for the max_tokens parameter
        if request.max_tokens or self.config.max_tokens:
            request_body["num_predict"] = request.max_tokens or self.config.max_tokens

        url = f"{self.base_url.rstrip('/')}/chat/completions"
        response = await self.client.post(
            url,
            headers=self.build_headers(headers) if headers else self.build_headers(),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        choice = data.get("choices", [{}])[0]
        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        # Ollama runs locally and needs no API key
        if not self.config.model:
            raise Exception("未指定Ollama模型")
        return True


@@ -1,93 +0,0 @@
"""
OpenAI adapter (supports the GPT series and OpenAI-compatible APIs)
"""
from typing import Dict, Any

from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class OpenAIAdapter(BaseLLMAdapter):
    """OpenAI adapter"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.OPENAI, "https://api.openai.com/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "OpenAI API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # Build request headers
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        # Detect reasoning models (o1/o3 series)
        model_name = self.config.model.lower()
        is_reasoning_model = "o1" in model_name or "o3" in model_name

        # Build the request body
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
            "frequency_penalty": self.config.frequency_penalty,
            "presence_penalty": self.config.presence_penalty,
        }

        # Reasoning models take max_completion_tokens; other models take max_tokens
        max_tokens = request.max_tokens if request.max_tokens is not None else self.config.max_tokens
        if is_reasoning_model:
            request_body["max_completion_tokens"] = max_tokens
        else:
            request_body["max_tokens"] = max_tokens

        url = f"{self.base_url.rstrip('/')}/chat/completions"
        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        choice = data.get("choices", [{}])[0]
        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定OpenAI模型")
        return True


@@ -1,80 +0,0 @@
"""
Alibaba Cloud Tongyi Qianwen (Qwen) adapter - OpenAI-compatible format
"""
from typing import Dict, Any

from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class QwenAdapter(BaseLLMAdapter):
    """Tongyi Qianwen adapter"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.QWEN, "https://dashscope.aliyuncs.com/compatible-mode/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "通义千问 API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # Tongyi Qianwen is OpenAI-compatible
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        url = f"{self.base_url.rstrip('/')}/chat/completions"
        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        choice = data.get("choices", [{}])[0]
        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定通义千问模型")
        return True


@@ -1,80 +0,0 @@
"""
Zhipu AI adapter (GLM series) - OpenAI-compatible format
"""
from typing import Dict, Any

from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class ZhipuAdapter(BaseLLMAdapter):
    """Zhipu AI adapter"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.ZHIPU, "https://open.bigmodel.cn/api/paas/v4")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "智谱AI API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # Zhipu AI is OpenAI-compatible
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        url = f"{self.base_url.rstrip('/')}/chat/completions"
        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        choice = data.get("choices", [{}])[0]
        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定智谱AI模型")
        return True


@@ -1,95 +1,105 @@
 """
 LLM factory - creates and manages LLM adapters in one place
+
+Uses LiteLLM as the primary adapter, covering most LLM providers;
+providers with non-standard API formats (Baidu, MiniMax, Doubao) use native adapters.
 """
-from typing import Dict, List, Optional
+from typing import Dict, List
 from .types import LLMConfig, LLMProvider, DEFAULT_MODELS
 from .base_adapter import BaseLLMAdapter
 from .adapters import (
-    OpenAIAdapter,
-    GeminiAdapter,
-    ClaudeAdapter,
-    DeepSeekAdapter,
-    QwenAdapter,
-    ZhipuAdapter,
-    MoonshotAdapter,
+    LiteLLMAdapter,
     BaiduAdapter,
     MinimaxAdapter,
     DoubaoAdapter,
-    OllamaAdapter,
 )

+# Providers that must use native adapters (non-standard API formats)
+NATIVE_ONLY_PROVIDERS = {
+    LLMProvider.BAIDU,
+    LLMProvider.MINIMAX,
+    LLMProvider.DOUBAO,
+}
+

 class LLMFactory:
     """LLM factory"""

     _adapters: Dict[str, BaseLLMAdapter] = {}

     @classmethod
     def create_adapter(cls, config: LLMConfig) -> BaseLLMAdapter:
         """Create an LLM adapter instance"""
         cache_key = cls._get_cache_key(config)

         # Return a cached instance if present
         if cache_key in cls._adapters:
             return cls._adapters[cache_key]

         # Create a new adapter instance
         adapter = cls._instantiate_adapter(config)

         # Cache the instance
         cls._adapters[cache_key] = adapter

         return adapter

     @classmethod
     def _instantiate_adapter(cls, config: LLMConfig) -> BaseLLMAdapter:
         """Instantiate an adapter based on the provider type"""
         # Fall back to the default model if none is specified
         if not config.model:
             config.model = DEFAULT_MODELS.get(config.provider, "gpt-4o-mini")

-        adapter_map = {
-            LLMProvider.OPENAI: OpenAIAdapter,
-            LLMProvider.GEMINI: GeminiAdapter,
-            LLMProvider.CLAUDE: ClaudeAdapter,
-            LLMProvider.DEEPSEEK: DeepSeekAdapter,
-            LLMProvider.QWEN: QwenAdapter,
-            LLMProvider.ZHIPU: ZhipuAdapter,
-            LLMProvider.MOONSHOT: MoonshotAdapter,
+        # Providers that must use a native adapter
+        if config.provider in NATIVE_ONLY_PROVIDERS:
+            return cls._create_native_adapter(config)
+
+        # All other providers go through LiteLLM
+        if LiteLLMAdapter.supports_provider(config.provider):
+            return LiteLLMAdapter(config)
+
+        # Unsupported provider
+        raise ValueError(f"不支持的LLM提供商: {config.provider}")
+
+    @classmethod
+    def _create_native_adapter(cls, config: LLMConfig) -> BaseLLMAdapter:
+        """Create a native adapter (only for providers with non-standard API formats)"""
+        native_adapter_map = {
             LLMProvider.BAIDU: BaiduAdapter,
             LLMProvider.MINIMAX: MinimaxAdapter,
             LLMProvider.DOUBAO: DoubaoAdapter,
-            LLMProvider.OLLAMA: OllamaAdapter,
         }

-        adapter_class = adapter_map.get(config.provider)
+        adapter_class = native_adapter_map.get(config.provider)
         if not adapter_class:
-            raise ValueError(f"不支持的LLM提供商: {config.provider}")
+            raise ValueError(f"不支持的原生适配器提供商: {config.provider}")

         return adapter_class(config)

     @classmethod
     def _get_cache_key(cls, config: LLMConfig) -> str:
         """Build a cache key"""
         api_key_prefix = config.api_key[:8] if config.api_key else "no-key"
         return f"{config.provider.value}:{config.model}:{api_key_prefix}"

     @classmethod
     def clear_cache(cls) -> None:
         """Clear the adapter cache"""
         cls._adapters.clear()

     @classmethod
     def get_supported_providers(cls) -> List[LLMProvider]:
         """Return the list of supported providers"""
         return list(LLMProvider)

     @classmethod
     def get_default_model(cls, provider: LLMProvider) -> str:
         """Return the provider's default model"""
         return DEFAULT_MODELS.get(provider, "gpt-4o-mini")

     @classmethod
     def get_available_models(cls, provider: LLMProvider) -> List[str]:
         """Return the provider's available model list"""
@@ -162,4 +172,3 @@ class LLMFactory:
             ],
         }
         return models.get(provider, [])
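The net routing behaviour, illustrated (model strings as used elsewhere in this commit; the key values are placeholders):

    from app.services.llm.factory import LLMFactory
    from app.services.llm.types import LLMConfig, LLMProvider

    claude = LLMFactory.create_adapter(
        LLMConfig(provider=LLMProvider.CLAUDE, api_key="...", model="claude-3-5-sonnet-20241022")
    )
    baidu = LLMFactory.create_adapter(
        LLMConfig(provider=LLMProvider.BAIDU, api_key="ak:sk", model="ERNIE-3.5-8K")
    )
    assert type(claude).__name__ == "LiteLLMAdapter"  # unified path
    assert type(baidu).__name__ == "BaiduAdapter"     # native path
    # Identical provider/model/api-key-prefix combinations are then served from the cache.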


@@ -16,6 +16,9 @@ ACCESS_TOKEN_EXPIRE_MINUTES=11520

 # ------------ LLM configuration ------------
 # Supported providers: openai, gemini, claude, qwen, deepseek, zhipu, moonshot, baidu, minimax, doubao, ollama
+#
+# Via the LiteLLM unified adapter: openai, gemini, claude, qwen, deepseek, zhipu, moonshot, ollama
+# Via native adapters (non-standard API formats): baidu, minimax, doubao
 LLM_PROVIDER=openai
 LLM_API_KEY=sk-your-api-key
 LLM_MODEL=
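For example, a local Ollama setup could look like this (placeholder values; per the LiteLLM adapter above, the literal key "ollama" is treated as "no key required"):

    LLM_PROVIDER=ollama
    LLM_API_KEY=ollama
    LLM_MODEL=llama3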


@@ -18,4 +18,5 @@ dependencies = [
     "email-validator",
     "greenlet",
     "bcrypt<5.0.0",
+    "litellm>=1.0.0",
 ]

(File diff suppressed because it is too large.)


@@ -273,6 +273,34 @@ export const api = {
     await apiClient.delete('/config/me');
   },

  async testLLMConnection(params: {
    provider: string;
    apiKey: string;
    model?: string;
    baseUrl?: string;
  }): Promise<{
    success: boolean;
    message: string;
    model?: string;
    response?: string;
  }> {
    const res = await apiClient.post('/config/test-llm', params);
    return res.data;
  },

  async getLLMProviders(): Promise<{
    providers: Array<{
      id: string;
      name: string;
      defaultModel: string;
      models: string[];
      defaultBaseUrl: string;
    }>;
  }> {
    const res = await apiClient.get('/config/llm-providers');
    return res.data;
  },

  // ==================== Database management methods ====================

  async exportDatabase(): Promise<{