refactor(llm): consolidate LLM adapters with LiteLLM unified layer
- Replace the individual adapter implementations (OpenAI, Claude, Gemini, DeepSeek, Qwen, Zhipu, Moonshot, Ollama) with a unified LiteLLM adapter
- Keep native adapters for providers with special API formats (Baidu, MiniMax, Doubao)
- Update the LLM factory to route requests through LiteLLM for supported providers
- Add a test-llm endpoint to validate LLM connections with a configurable timeout and token limit
- Add a get-llm-providers endpoint to retrieve supported providers and their configurations
- Update config.py to ignore extra environment variables (VITE_* frontend variables)
- Refactor the Baidu adapter to the new complete() method signature and improve its error handling
- Add the litellm package to the pyproject.toml dependencies
- Update env.example with the new configuration options
- Simplify adapter initialization and reduce code duplication across provider implementations
parent 1fc0ecd14a
commit 22c528acf1
@@ -216,3 +216,109 @@ async def delete_my_config(
     return {"message": "Configuration deleted"}
+
+
+class LLMTestRequest(BaseModel):
+    """LLM test request"""
+    provider: str
+    apiKey: str
+    model: Optional[str] = None
+    baseUrl: Optional[str] = None
+
+
+class LLMTestResponse(BaseModel):
+    """LLM test response"""
+    success: bool
+    message: str
+    model: Optional[str] = None
+    response: Optional[str] = None
+
+
+@router.post("/test-llm", response_model=LLMTestResponse)
+async def test_llm_connection(
+    request: LLMTestRequest,
+    current_user: User = Depends(deps.get_current_user),
+) -> Any:
+    """Test whether an LLM connection works"""
+    from app.services.llm.factory import LLMFactory
+    from app.services.llm.types import LLMConfig, LLMProvider, LLMRequest, LLMMessage, DEFAULT_MODELS
+
+    try:
+        # Resolve the provider
+        provider_map = {
+            'gemini': LLMProvider.GEMINI,
+            'openai': LLMProvider.OPENAI,
+            'claude': LLMProvider.CLAUDE,
+            'qwen': LLMProvider.QWEN,
+            'deepseek': LLMProvider.DEEPSEEK,
+            'zhipu': LLMProvider.ZHIPU,
+            'moonshot': LLMProvider.MOONSHOT,
+            'baidu': LLMProvider.BAIDU,
+            'minimax': LLMProvider.MINIMAX,
+            'doubao': LLMProvider.DOUBAO,
+            'ollama': LLMProvider.OLLAMA,
+        }
+
+        provider = provider_map.get(request.provider.lower())
+        if not provider:
+            return LLMTestResponse(
+                success=False,
+                message=f"Unsupported LLM provider: {request.provider}"
+            )
+
+        # Get the default model
+        model = request.model or DEFAULT_MODELS.get(provider)
+
+        # Build the config
+        config = LLMConfig(
+            provider=provider,
+            api_key=request.apiKey,
+            model=model,
+            base_url=request.baseUrl,
+            timeout=30,  # Use a short timeout for testing
+            max_tokens=50,  # Use a small token budget for testing
+        )
+
+        # Create the adapter and run a test call
+        adapter = LLMFactory.create_adapter(config)
+
+        test_request = LLMRequest(
+            messages=[
+                LLMMessage(role="user", content="Say 'Hello' in one word.")
+            ],
+            max_tokens=50,
+        )
+
+        response = await adapter.complete(test_request)
+
+        return LLMTestResponse(
+            success=True,
+            message="LLM connection test succeeded",
+            model=model,
+            response=response.content[:100] if response.content else None
+        )
+
+    except Exception as e:
+        return LLMTestResponse(
+            success=False,
+            message=f"LLM connection test failed: {str(e)}"
+        )
+
+
+@router.get("/llm-providers")
+async def get_llm_providers() -> Any:
+    """Get the list of supported LLM providers"""
+    from app.services.llm.factory import LLMFactory
+    from app.services.llm.types import LLMProvider, DEFAULT_BASE_URLS
+
+    providers = []
+    for provider in LLMFactory.get_supported_providers():
+        providers.append({
+            "id": provider.value,
+            "name": provider.value.upper(),
+            "defaultModel": LLMFactory.get_default_model(provider),
+            "models": LLMFactory.get_available_models(provider),
+            "defaultBaseUrl": DEFAULT_BASE_URLS.get(provider, ""),
+        })
+
+    return {"providers": providers}
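As a quick sanity check of the new endpoint, the request can be driven directly over HTTP. A minimal sketch, assuming a local dev server on http://localhost:8000 with the API mounted under /api/v1 and a valid bearer token (both are assumptions, not part of this commit):

import asyncio
import httpx

async def main() -> None:
    # Base URL and token are placeholders; adjust to your deployment
    async with httpx.AsyncClient(base_url="http://localhost:8000/api/v1") as client:
        res = await client.post(
            "/config/test-llm",
            headers={"Authorization": "Bearer <your-token>"},  # hypothetical token
            json={"provider": "openai", "apiKey": "sk-...", "model": "gpt-4o-mini"},
        )
        # Expected shape: {"success": ..., "message": ..., "model": ..., "response": ...}
        print(res.json())

asyncio.run(main())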
config.py

@@ -77,6 +77,7 @@ class Settings(BaseSettings):
     class Config:
         case_sensitive = True
         env_file = ".env"
+        extra = "ignore"  # Ignore extra environment variables (e.g. VITE_* frontend variables)
 
 
 settings = Settings()
adapters/__init__.py

@@ -1,30 +1,22 @@
 """
 LLM adapter module
+
+Adapters fall into two categories:
+1. LiteLLM unified adapter - supports OpenAI, Claude, Gemini, DeepSeek, Qwen, Zhipu, Moonshot, Ollama
+2. Native adapters - for providers with special API formats: Baidu, MiniMax, Doubao
 """
 
-from .openai_adapter import OpenAIAdapter
-from .gemini_adapter import GeminiAdapter
-from .claude_adapter import ClaudeAdapter
-from .deepseek_adapter import DeepSeekAdapter
-from .qwen_adapter import QwenAdapter
-from .zhipu_adapter import ZhipuAdapter
-from .moonshot_adapter import MoonshotAdapter
+# LiteLLM unified adapter
+from .litellm_adapter import LiteLLMAdapter
+
+# Native adapters (for providers with special API formats)
 from .baidu_adapter import BaiduAdapter
 from .minimax_adapter import MinimaxAdapter
 from .doubao_adapter import DoubaoAdapter
-from .ollama_adapter import OllamaAdapter
 
 __all__ = [
-    'OpenAIAdapter',
-    'GeminiAdapter',
-    'ClaudeAdapter',
-    'DeepSeekAdapter',
-    'QwenAdapter',
-    'ZhipuAdapter',
-    'MoonshotAdapter',
-    'BaiduAdapter',
-    'MinimaxAdapter',
-    'DoubaoAdapter',
-    'OllamaAdapter',
+    "LiteLLMAdapter",
+    "BaiduAdapter",
+    "MinimaxAdapter",
+    "DoubaoAdapter",
 ]
baidu_adapter.py

@@ -3,10 +3,9 @@
 """
 
 import httpx
-import json
 from typing import Optional
 from ..base_adapter import BaseLLMAdapter
-from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
+from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage
 
 
 class BaiduAdapter(BaseLLMAdapter):
@@ -70,8 +69,16 @@ class BaiduAdapter(BaseLLMAdapter):
 
         return self._access_token
 
-    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
-        """Execute the actual API call"""
+    async def complete(self, request: LLMRequest) -> LLMResponse:
+        """Execute the actual API call"""
+        try:
+            await self.validate_config()
+            return await self.retry(lambda: self._send_request(request))
+        except Exception as error:
+            self.handle_error(error, "Baidu ERNIE API call failed")
+
+    async def _send_request(self, request: LLMRequest) -> LLMResponse:
+        """Send the request"""
         access_token = await self._get_access_token()
 
         # Get the API endpoint for the model
@@ -84,55 +91,56 @@ class BaiduAdapter(BaseLLMAdapter):
 
         payload = {
             "messages": messages,
-            "temperature": request.temperature or self.config.temperature,
-            "top_p": request.top_p or self.config.top_p,
+            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
+            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
         }
 
         if request.max_tokens or self.config.max_tokens:
             payload["max_output_tokens"] = request.max_tokens or self.config.max_tokens
 
-        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
-            response = await client.post(
-                url,
-                json=payload,
-                headers={"Content-Type": "application/json"}
-            )
+        response = await self.client.post(
+            url,
+            headers=self.build_headers(),
+            json=payload
+        )
 
         if response.status_code != 200:
-            raise LLMError(
-                f"Baidu API error: {response.text}",
-                provider="baidu",
-                status_code=response.status_code
-            )
+            error_data = response.json() if response.text else {}
+            error_msg = error_data.get("error_msg", f"HTTP {response.status_code}")
+            raise Exception(f"{error_msg}")
 
         data = response.json()
 
         if "error_code" in data:
-            raise LLMError(
-                f"Baidu API error: {data.get('error_msg', 'unknown error')}",
-                provider="baidu",
-                status_code=data.get("error_code")
-            )
+            raise Exception(f"Baidu API error: {data.get('error_msg', 'unknown error')}")
+
+        usage = None
+        if "usage" in data:
+            usage = LLMUsage(
+                prompt_tokens=data["usage"].get("prompt_tokens", 0),
+                completion_tokens=data["usage"].get("completion_tokens", 0),
+                total_tokens=data["usage"].get("total_tokens", 0)
+            )
 
         return LLMResponse(
             content=data.get("result", ""),
             model=model,
-            usage=data.get("usage"),
+            usage=usage,
            finish_reason=data.get("finish_reason")
         )
 
     async def validate_config(self) -> bool:
         """Validate that the configuration is usable"""
-        try:
-            await self._get_access_token()
-            return True
-        except Exception:
-            return False
-
-    def get_provider(self) -> str:
-        return "baidu"
-
-    def get_model(self) -> str:
-        return self.config.model or "ERNIE-3.5-8K"
+        if not self.config.api_key:
+            raise LLMError(
+                "API key not configured",
+                provider=LLMProvider.BAIDU
+            )
+        if ":" not in self.config.api_key:
+            raise LLMError(
+                "Baidu requires both an API Key and a Secret Key, in the form api_key:secret_key",
+                provider=LLMProvider.BAIDU
+            )
+        return True
claude_adapter.py (deleted)

@@ -1,94 +0,0 @@
-"""
-Anthropic Claude adapter
-"""
-
-from typing import Dict, Any
-from ..base_adapter import BaseLLMAdapter
-from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider
-
-
-class ClaudeAdapter(BaseLLMAdapter):
-    """Claude adapter"""
-
-    @property
-    def base_url(self) -> str:
-        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.CLAUDE, "https://api.anthropic.com/v1")
-
-    async def complete(self, request: LLMRequest) -> LLMResponse:
-        try:
-            await self.validate_config()
-            return await self.retry(lambda: self._send_request(request))
-        except Exception as error:
-            self.handle_error(error, "Claude API call failed")
-
-    async def _send_request(self, request: LLMRequest) -> LLMResponse:
-        # The Claude API requires system messages to be split out
-        system_message = None
-        messages = []
-
-        for msg in request.messages:
-            if msg.role == "system":
-                system_message = msg.content
-            else:
-                messages.append({
-                    "role": msg.role,
-                    "content": msg.content
-                })
-
-        request_body: Dict[str, Any] = {
-            "model": self.config.model,
-            "messages": messages,
-            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens or 4096,
-            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
-            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
-        }
-
-        if system_message:
-            request_body["system"] = system_message
-
-        # Build the request headers
-        headers = {
-            "x-api-key": self.config.api_key,
-            "anthropic-version": "2023-06-01",
-        }
-
-        url = f"{self.base_url.rstrip('/')}/messages"
-
-        response = await self.client.post(
-            url,
-            headers=self.build_headers(headers),
-            json=request_body
-        )
-
-        if response.status_code != 200:
-            error_data = response.json() if response.text else {}
-            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
-            raise Exception(f"{error_msg}")
-
-        data = response.json()
-
-        if not data.get("content") or not data["content"][0]:
-            raise Exception("Unexpected API response: missing content field")
-
-        usage = None
-        if "usage" in data:
-            usage = LLMUsage(
-                prompt_tokens=data["usage"].get("input_tokens", 0),
-                completion_tokens=data["usage"].get("output_tokens", 0),
-                total_tokens=data["usage"].get("input_tokens", 0) + data["usage"].get("output_tokens", 0)
-            )
-
-        return LLMResponse(
-            content=data["content"][0].get("text", ""),
-            model=data.get("model"),
-            usage=usage,
-            finish_reason=data.get("stop_reason")
-        )
-
-    async def validate_config(self) -> bool:
-        await super().validate_config()
-        if not self.config.model.startswith("claude-"):
-            raise Exception(f"Invalid Claude model: {self.config.model}")
-        return True
deepseek_adapter.py (deleted)

@@ -1,82 +0,0 @@
-"""
-DeepSeek adapter - OpenAI-compatible format
-"""
-
-from typing import Dict, Any
-from ..base_adapter import BaseLLMAdapter
-from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider
-
-
-class DeepSeekAdapter(BaseLLMAdapter):
-    """DeepSeek adapter"""
-
-    @property
-    def base_url(self) -> str:
-        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.DEEPSEEK, "https://api.deepseek.com")
-
-    async def complete(self, request: LLMRequest) -> LLMResponse:
-        try:
-            await self.validate_config()
-            return await self.retry(lambda: self._send_request(request))
-        except Exception as error:
-            self.handle_error(error, "DeepSeek API call failed")
-
-    async def _send_request(self, request: LLMRequest) -> LLMResponse:
-        # The DeepSeek API is OpenAI-compatible
-        headers = {
-            "Authorization": f"Bearer {self.config.api_key}",
-        }
-
-        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
-        request_body: Dict[str, Any] = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
-            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
-            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
-            "frequency_penalty": self.config.frequency_penalty,
-            "presence_penalty": self.config.presence_penalty,
-        }
-
-        url = f"{self.base_url.rstrip('/')}/v1/chat/completions"
-
-        response = await self.client.post(
-            url,
-            headers=self.build_headers(headers),
-            json=request_body
-        )
-
-        if response.status_code != 200:
-            error_data = response.json() if response.text else {}
-            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
-            raise Exception(f"{error_msg}")
-
-        data = response.json()
-        choice = data.get("choices", [{}])[0]
-
-        if not choice:
-            raise Exception("Unexpected API response: missing choices field")
-
-        usage = None
-        if "usage" in data:
-            usage = LLMUsage(
-                prompt_tokens=data["usage"].get("prompt_tokens", 0),
-                completion_tokens=data["usage"].get("completion_tokens", 0),
-                total_tokens=data["usage"].get("total_tokens", 0)
-            )
-
-        return LLMResponse(
-            content=choice.get("message", {}).get("content", ""),
-            model=data.get("model"),
-            usage=usage,
-            finish_reason=choice.get("finish_reason")
-        )
-
-    async def validate_config(self) -> bool:
-        await super().validate_config()
-        if not self.config.model:
-            raise Exception("No DeepSeek model specified")
-        return True
doubao_adapter.py

@@ -2,9 +2,8 @@
 ByteDance Doubao adapter
 """
 
-import httpx
 from ..base_adapter import BaseLLMAdapter
-from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
+from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage
 
 
 class DoubaoAdapter(BaseLLMAdapter):
@@ -17,8 +16,16 @@ class DoubaoAdapter(BaseLLMAdapter):
         super().__init__(config)
         self._base_url = config.base_url or "https://ark.cn-beijing.volces.com/api/v3"
 
-    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
-        """Execute the actual API call"""
+    async def complete(self, request: LLMRequest) -> LLMResponse:
+        """Execute the actual API call"""
+        try:
+            await self.validate_config()
+            return await self.retry(lambda: self._send_request(request))
+        except Exception as error:
+            self.handle_error(error, "Doubao API call failed")
+
+    async def _send_request(self, request: LLMRequest) -> LLMResponse:
+        """Send the request"""
         url = f"{self._base_url}/chat/completions"
 
         messages = [{"role": m.role, "content": m.content} for m in request.messages]
@@ -26,63 +33,52 @@ class DoubaoAdapter(BaseLLMAdapter):
         payload = {
             "model": self.config.model or "doubao-pro-32k",
             "messages": messages,
-            "temperature": request.temperature or self.config.temperature,
-            "top_p": request.top_p or self.config.top_p,
+            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
+            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
+            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
         }
 
-        if request.max_tokens or self.config.max_tokens:
-            payload["max_tokens"] = request.max_tokens or self.config.max_tokens
-
         headers = {
-            "Content-Type": "application/json",
             "Authorization": f"Bearer {self.config.api_key}",
         }
 
-        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
-            response = await client.post(url, json=payload, headers=headers)
+        response = await self.client.post(
+            url,
+            headers=self.build_headers(headers),
+            json=payload
+        )
 
         if response.status_code != 200:
-            raise LLMError(
-                f"Doubao API error: {response.text}",
-                provider="doubao",
-                status_code=response.status_code
-            )
+            error_data = response.json() if response.text else {}
+            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
+            raise Exception(f"{error_msg}")
 
         data = response.json()
+        choice = data.get("choices", [{}])[0]
 
-        if "error" in data:
-            raise LLMError(
-                f"Doubao API error: {data['error'].get('message', 'unknown error')}",
-                provider="doubao"
-            )
+        if not choice:
+            raise Exception("Unexpected API response: missing choices field")
 
-        choices = data.get("choices", [])
-        if not choices:
-            raise LLMError("Doubao API returned an empty response", provider="doubao")
+        usage = None
+        if "usage" in data:
+            usage = LLMUsage(
+                prompt_tokens=data["usage"].get("prompt_tokens", 0),
+                completion_tokens=data["usage"].get("completion_tokens", 0),
+                total_tokens=data["usage"].get("total_tokens", 0)
+            )
 
         return LLMResponse(
-            content=choices[0].get("message", {}).get("content", ""),
-            model=data.get("model", self.config.model or "doubao-pro-32k"),
-            usage=data.get("usage"),
-            finish_reason=choices[0].get("finish_reason")
+            content=choice.get("message", {}).get("content", ""),
+            model=data.get("model"),
+            usage=usage,
+            finish_reason=choice.get("finish_reason")
         )
 
     async def validate_config(self) -> bool:
         """Validate that the configuration is usable"""
-        try:
-            test_request = LLMRequest(
-                messages=[{"role": "user", "content": "Hi"}],
-                max_tokens=10
-            )
-            await self._do_complete(test_request)
-            return True
-        except Exception:
-            return False
-
-    def get_provider(self) -> str:
-        return "doubao"
-
-    def get_model(self) -> str:
-        return self.config.model or "doubao-pro-32k"
+        await super().validate_config()
+        if not self.config.model:
+            raise LLMError("No Doubao model specified", provider=LLMProvider.DOUBAO)
+        return True
gemini_adapter.py (deleted)

@@ -1,114 +0,0 @@
-"""
-Google Gemini adapter - supports the official API and relay endpoints
-"""
-
-from typing import Dict, Any, List
-from ..base_adapter import BaseLLMAdapter
-from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider
-
-
-class GeminiAdapter(BaseLLMAdapter):
-    """Gemini adapter"""
-
-    @property
-    def base_url(self) -> str:
-        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.GEMINI, "https://generativelanguage.googleapis.com/v1beta")
-
-    async def complete(self, request: LLMRequest) -> LLMResponse:
-        try:
-            await self.validate_config()
-            return await self.retry(lambda: self._generate_content(request))
-        except Exception as error:
-            self.handle_error(error, "Gemini API call failed")
-
-    async def _generate_content(self, request: LLMRequest) -> LLMResponse:
-        # Convert messages into the Gemini format
-        contents: List[Dict[str, Any]] = []
-        system_content = ""
-
-        for msg in request.messages:
-            if msg.role == "system":
-                system_content = msg.content
-            else:
-                role = "model" if msg.role == "assistant" else "user"
-                contents.append({
-                    "role": role,
-                    "parts": [{"text": msg.content}]
-                })
-
-        # Merge the system message into the first user message
-        if system_content and contents:
-            contents[0]["parts"][0]["text"] = f"{system_content}\n\n{contents[0]['parts'][0]['text']}"
-
-        # Build the request body
-        request_body = {
-            "contents": contents,
-            "generationConfig": {
-                "temperature": request.temperature if request.temperature is not None else self.config.temperature,
-                "maxOutputTokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
-                "topP": request.top_p if request.top_p is not None else self.config.top_p,
-            }
-        }
-
-        # The API key goes in a URL parameter
-        url = f"{self.base_url}/models/{self.config.model}:generateContent?key={self.config.api_key}"
-
-        response = await self.client.post(
-            url,
-            headers=self.build_headers(),
-            json=request_body
-        )
-
-        if response.status_code != 200:
-            error_data = response.json() if response.text else {}
-            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
-            raise Exception(f"{error_msg}")
-
-        data = response.json()
-
-        # Parse the Gemini response format
-        candidates = data.get("candidates", [])
-        if not candidates:
-            # Check for an error message
-            if "error" in data:
-                error_msg = data["error"].get("message", "unknown error")
-                raise Exception(f"Gemini API error: {error_msg}")
-            raise Exception("Unexpected API response: missing candidates field")
-
-        candidate = candidates[0]
-        if not candidate or "content" not in candidate:
-            raise Exception("Unexpected API response: missing content field")
-
-        text_parts = candidate.get("content", {}).get("parts", [])
-        if not text_parts:
-            raise Exception("Unexpected API response: content.parts is empty")
-
-        text = "".join(part.get("text", "") for part in text_parts)
-
-        # Check for an empty response
-        if not text or not text.strip():
-            finish_reason = candidate.get("finishReason", "unknown")
-            raise Exception(f"Gemini returned an empty response - Finish Reason: {finish_reason}")
-
-        usage = None
-        if "usageMetadata" in data:
-            usage_data = data["usageMetadata"]
-            usage = LLMUsage(
-                prompt_tokens=usage_data.get("promptTokenCount", 0),
-                completion_tokens=usage_data.get("candidatesTokenCount", 0),
-                total_tokens=usage_data.get("totalTokenCount", 0)
-            )
-
-        return LLMResponse(
-            content=text,
-            model=self.config.model,
-            usage=usage,
-            finish_reason=candidate.get("finishReason", "stop")
-        )
-
-    async def validate_config(self) -> bool:
-        await super().validate_config()
-        if not self.config.model.startswith("gemini-"):
-            raise Exception(f"Invalid Gemini model: {self.config.model}")
-        return True
litellm_adapter.py (new)

@@ -0,0 +1,182 @@
+"""
+LiteLLM unified adapter
+Supports calling multiple LLM providers through LiteLLM, using a unified OpenAI-compatible format
+"""
+
+from typing import Dict, Any, Optional
+from ..base_adapter import BaseLLMAdapter
+from ..types import (
+    LLMConfig,
+    LLMRequest,
+    LLMResponse,
+    LLMUsage,
+    LLMProvider,
+    LLMError,
+    DEFAULT_BASE_URLS,
+)
+
+
+class LiteLLMAdapter(BaseLLMAdapter):
+    """
+    LiteLLM unified adapter
+
+    Supported providers:
+    - OpenAI (openai/gpt-4o-mini)
+    - Claude (anthropic/claude-3-5-sonnet-20241022)
+    - Gemini (gemini/gemini-1.5-flash)
+    - DeepSeek (deepseek/deepseek-chat)
+    - Qwen (qwen/qwen-turbo) - via OpenAI-compatible mode
+    - Zhipu (zhipu/glm-4-flash) - via OpenAI-compatible mode
+    - Moonshot (moonshot/moonshot-v1-8k) - via OpenAI-compatible mode
+    - Ollama (ollama/llama3)
+    """
+
+    # LiteLLM model prefix map
+    PROVIDER_PREFIX_MAP = {
+        LLMProvider.OPENAI: "openai",
+        LLMProvider.CLAUDE: "anthropic",
+        LLMProvider.GEMINI: "gemini",
+        LLMProvider.DEEPSEEK: "deepseek",
+        LLMProvider.QWEN: "openai",  # uses OpenAI-compatible mode
+        LLMProvider.ZHIPU: "openai",  # uses OpenAI-compatible mode
+        LLMProvider.MOONSHOT: "openai",  # uses OpenAI-compatible mode
+        LLMProvider.OLLAMA: "ollama",
+    }
+
+    # Providers that need a custom base_url
+    CUSTOM_BASE_URL_PROVIDERS = {
+        LLMProvider.QWEN,
+        LLMProvider.ZHIPU,
+        LLMProvider.MOONSHOT,
+        LLMProvider.DEEPSEEK,
+    }
+
+    def __init__(self, config: LLMConfig):
+        super().__init__(config)
+        self._litellm_model = self._get_litellm_model()
+        self._api_base = self._get_api_base()
+
+    def _get_litellm_model(self) -> str:
+        """Get the model name in LiteLLM format"""
+        provider = self.config.provider
+        model = self.config.model
+
+        # Providers on OpenAI-compatible mode use the bare model name
+        if provider in self.CUSTOM_BASE_URL_PROVIDERS:
+            return model
+
+        # Natively supported providers get a prefix
+        prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")
+
+        # Check whether the model name already carries a prefix
+        if "/" in model:
+            return model
+
+        return f"{prefix}/{model}"
+
+    def _get_api_base(self) -> Optional[str]:
+        """Get the API base URL"""
+        # Prefer a user-configured base_url
+        if self.config.base_url:
+            return self.config.base_url
+
+        # Providers that need a custom base_url fall back to the default
+        if self.config.provider in self.CUSTOM_BASE_URL_PROVIDERS:
+            return DEFAULT_BASE_URLS.get(self.config.provider)
+
+        # Ollama uses a local address
+        if self.config.provider == LLMProvider.OLLAMA:
+            return DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434")
+
+        return None
+
+    async def complete(self, request: LLMRequest) -> LLMResponse:
+        """Send the request via LiteLLM"""
+        try:
+            await self.validate_config()
+            return await self.retry(lambda: self._send_request(request))
+        except Exception as error:
+            self.handle_error(error, f"LiteLLM ({self.config.provider.value}) API call failed")
+
+    async def _send_request(self, request: LLMRequest) -> LLMResponse:
+        """Send the request to LiteLLM"""
+        import litellm
+
+        # Build the messages
+        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
+
+        # Build the request parameters
+        kwargs: Dict[str, Any] = {
+            "model": self._litellm_model,
+            "messages": messages,
+            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
+            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
+            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
+        }
+
+        # Set the API key
+        if self.config.api_key and self.config.api_key != "ollama":
+            kwargs["api_key"] = self.config.api_key
+
+        # Set the API base URL
+        if self._api_base:
+            kwargs["api_base"] = self._api_base
+
+        # Set the timeout
+        kwargs["timeout"] = self.config.timeout
+
+        # Add extra parameters for the OpenAI provider
+        if self.config.provider == LLMProvider.OPENAI:
+            kwargs["frequency_penalty"] = self.config.frequency_penalty
+            kwargs["presence_penalty"] = self.config.presence_penalty
+
+        # Call LiteLLM
+        response = await litellm.acompletion(**kwargs)
+
+        # Parse the response
+        choice = response.choices[0] if response.choices else None
+        if not choice:
+            raise LLMError("Unexpected API response: missing choices field", self.config.provider)
+
+        usage = None
+        if hasattr(response, "usage") and response.usage:
+            usage = LLMUsage(
+                prompt_tokens=response.usage.prompt_tokens or 0,
+                completion_tokens=response.usage.completion_tokens or 0,
+                total_tokens=response.usage.total_tokens or 0,
+            )
+
+        return LLMResponse(
+            content=choice.message.content or "",
+            model=response.model,
+            usage=usage,
+            finish_reason=choice.finish_reason,
+        )
+
+    async def validate_config(self) -> bool:
+        """Validate the configuration"""
+        # Ollama does not need an API key
+        if self.config.provider == LLMProvider.OLLAMA:
+            if not self.config.model:
+                raise LLMError("No Ollama model specified", LLMProvider.OLLAMA)
+            return True
+
+        # All other providers need an API key
+        if not self.config.api_key:
+            raise LLMError(
+                f"API key not configured ({self.config.provider.value})",
+                self.config.provider,
+            )
+
+        if not self.config.model:
+            raise LLMError(
+                f"No model specified ({self.config.provider.value})",
+                self.config.provider,
+            )
+
+        return True
+
+    @classmethod
+    def supports_provider(cls, provider: LLMProvider) -> bool:
+        """Check whether the given provider is supported"""
+        return provider in cls.PROVIDER_PREFIX_MAP
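To make the model-name mapping concrete, here is roughly the call the adapter ends up issuing for a natively supported provider, written against litellm's public acompletion API; the model name and API key below are placeholders, not values from this commit:

import asyncio
import litellm

async def main() -> None:
    # Claude routes as "anthropic/<model>"; the OpenAI-compatible providers
    # (Qwen, Zhipu, Moonshot, DeepSeek) instead pass the bare model name
    # together with an api_base pointing at the provider's endpoint.
    response = await litellm.acompletion(
        model="anthropic/claude-3-5-sonnet-20241022",
        messages=[{"role": "user", "content": "Say 'Hello' in one word."}],
        api_key="sk-ant-...",  # placeholder
        max_tokens=50,
        timeout=30,
    )
    print(response.choices[0].message.content)

asyncio.run(main())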
minimax_adapter.py

@@ -2,9 +2,8 @@
 MiniMax adapter
 """
 
-import httpx
 from ..base_adapter import BaseLLMAdapter
-from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
+from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage
 
 
 class MinimaxAdapter(BaseLLMAdapter):
@@ -14,8 +13,16 @@ class MinimaxAdapter(BaseLLMAdapter):
         super().__init__(config)
         self._base_url = config.base_url or "https://api.minimax.chat/v1"
 
-    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
-        """Execute the actual API call"""
+    async def complete(self, request: LLMRequest) -> LLMResponse:
+        """Execute the actual API call"""
+        try:
+            await self.validate_config()
+            return await self.retry(lambda: self._send_request(request))
+        except Exception as error:
+            self.handle_error(error, "MiniMax API call failed")
+
+    async def _send_request(self, request: LLMRequest) -> LLMResponse:
+        """Send the request"""
         url = f"{self._base_url}/text/chatcompletion_v2"
 
         messages = [{"role": m.role, "content": m.content} for m in request.messages]
@@ -23,63 +30,58 @@ class MinimaxAdapter(BaseLLMAdapter):
         payload = {
             "model": self.config.model or "abab6.5-chat",
             "messages": messages,
-            "temperature": request.temperature or self.config.temperature,
-            "top_p": request.top_p or self.config.top_p,
+            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
+            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
+            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
         }
 
-        if request.max_tokens or self.config.max_tokens:
-            payload["max_tokens"] = request.max_tokens or self.config.max_tokens
-
         headers = {
-            "Content-Type": "application/json",
             "Authorization": f"Bearer {self.config.api_key}",
         }
 
-        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
-            response = await client.post(url, json=payload, headers=headers)
+        response = await self.client.post(
+            url,
+            headers=self.build_headers(headers),
+            json=payload
+        )
 
         if response.status_code != 200:
-            raise LLMError(
-                f"MiniMax API error: {response.text}",
-                provider="minimax",
-                status_code=response.status_code
-            )
+            error_data = response.json() if response.text else {}
+            error_msg = error_data.get("base_resp", {}).get("status_msg", f"HTTP {response.status_code}")
+            raise Exception(f"{error_msg}")
 
         data = response.json()
 
+        # MiniMax-specific error handling
         if data.get("base_resp", {}).get("status_code") != 0:
-            raise LLMError(
-                f"MiniMax API error: {data.get('base_resp', {}).get('status_msg', 'unknown error')}",
-                provider="minimax"
-            )
+            error_msg = data.get("base_resp", {}).get("status_msg", "unknown error")
+            raise Exception(f"MiniMax API error: {error_msg}")
 
-        choices = data.get("choices", [])
-        if not choices:
-            raise LLMError("MiniMax API returned an empty response", provider="minimax")
+        choice = data.get("choices", [{}])[0]
+
+        if not choice:
+            raise Exception("Unexpected API response: missing choices field")
+
+        usage = None
+        if "usage" in data:
+            usage = LLMUsage(
+                prompt_tokens=data["usage"].get("prompt_tokens", 0),
+                completion_tokens=data["usage"].get("completion_tokens", 0),
+                total_tokens=data["usage"].get("total_tokens", 0)
+            )
 
         return LLMResponse(
-            content=choices[0].get("message", {}).get("content", ""),
-            model=self.config.model or "abab6.5-chat",
-            usage=data.get("usage"),
-            finish_reason=choices[0].get("finish_reason")
+            content=choice.get("message", {}).get("content", ""),
+            model=data.get("model"),
+            usage=usage,
+            finish_reason=choice.get("finish_reason")
        )
 
     async def validate_config(self) -> bool:
         """Validate that the configuration is usable"""
-        try:
-            test_request = LLMRequest(
-                messages=[{"role": "user", "content": "Hi"}],
-                max_tokens=10
-            )
-            await self._do_complete(test_request)
-            return True
-        except Exception:
-            return False
-
-    def get_provider(self) -> str:
-        return "minimax"
-
-    def get_model(self) -> str:
-        return self.config.model or "abab6.5-chat"
+        await super().validate_config()
+        if not self.config.model:
+            raise LLMError("No MiniMax model specified", provider=LLMProvider.MINIMAX)
+        return True
moonshot_adapter.py (deleted)

@@ -1,80 +0,0 @@
-"""
-Moonshot (Kimi) adapter - OpenAI-compatible format
-"""
-
-from typing import Dict, Any
-from ..base_adapter import BaseLLMAdapter
-from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider
-
-
-class MoonshotAdapter(BaseLLMAdapter):
-    """Moonshot Kimi adapter"""
-
-    @property
-    def base_url(self) -> str:
-        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.MOONSHOT, "https://api.moonshot.cn/v1")
-
-    async def complete(self, request: LLMRequest) -> LLMResponse:
-        try:
-            await self.validate_config()
-            return await self.retry(lambda: self._send_request(request))
-        except Exception as error:
-            self.handle_error(error, "Moonshot API call failed")
-
-    async def _send_request(self, request: LLMRequest) -> LLMResponse:
-        # The Moonshot API is OpenAI-compatible
-        headers = {
-            "Authorization": f"Bearer {self.config.api_key}",
-        }
-
-        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
-        request_body: Dict[str, Any] = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
-            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
-            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
-        }
-
-        url = f"{self.base_url.rstrip('/')}/chat/completions"
-
-        response = await self.client.post(
-            url,
-            headers=self.build_headers(headers),
-            json=request_body
-        )
-
-        if response.status_code != 200:
-            error_data = response.json() if response.text else {}
-            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
-            raise Exception(f"{error_msg}")
-
-        data = response.json()
-        choice = data.get("choices", [{}])[0]
-
-        if not choice:
-            raise Exception("Unexpected API response: missing choices field")
-
-        usage = None
-        if "usage" in data:
-            usage = LLMUsage(
-                prompt_tokens=data["usage"].get("prompt_tokens", 0),
-                completion_tokens=data["usage"].get("completion_tokens", 0),
-                total_tokens=data["usage"].get("total_tokens", 0)
-            )
-
-        return LLMResponse(
-            content=choice.get("message", {}).get("content", ""),
-            model=data.get("model"),
-            usage=usage,
-            finish_reason=choice.get("finish_reason")
-        )
-
-    async def validate_config(self) -> bool:
-        await super().validate_config()
-        if not self.config.model:
-            raise Exception("No Moonshot model specified")
-        return True
ollama_adapter.py (deleted)

@@ -1,83 +0,0 @@
-"""
-Ollama local LLM adapter - OpenAI-compatible format
-"""
-
-from typing import Dict, Any
-from ..base_adapter import BaseLLMAdapter
-from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider
-
-
-class OllamaAdapter(BaseLLMAdapter):
-    """Ollama local model adapter"""
-
-    @property
-    def base_url(self) -> str:
-        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434/v1")
-
-    async def complete(self, request: LLMRequest) -> LLMResponse:
-        try:
-            # Ollama runs locally, so the API key check is skipped
-            return await self.retry(lambda: self._send_request(request))
-        except Exception as error:
-            self.handle_error(error, "Ollama API call failed")
-
-    async def _send_request(self, request: LLMRequest) -> LLMResponse:
-        # Ollama is OpenAI-compatible
-        headers = {}
-        if self.config.api_key:
-            headers["Authorization"] = f"Bearer {self.config.api_key}"
-
-        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
-        request_body: Dict[str, Any] = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
-            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
-        }
-
-        # Ollama may use a different name for the max_tokens parameter
-        if request.max_tokens or self.config.max_tokens:
-            request_body["num_predict"] = request.max_tokens or self.config.max_tokens
-
-        url = f"{self.base_url.rstrip('/')}/chat/completions"
-
-        response = await self.client.post(
-            url,
-            headers=self.build_headers(headers) if headers else self.build_headers(),
-            json=request_body
-        )
-
-        if response.status_code != 200:
-            error_data = response.json() if response.text else {}
-            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
-            raise Exception(f"{error_msg}")
-
-        data = response.json()
-        choice = data.get("choices", [{}])[0]
-
-        if not choice:
-            raise Exception("Unexpected API response: missing choices field")
-
-        usage = None
-        if "usage" in data:
-            usage = LLMUsage(
-                prompt_tokens=data["usage"].get("prompt_tokens", 0),
-                completion_tokens=data["usage"].get("completion_tokens", 0),
-                total_tokens=data["usage"].get("total_tokens", 0)
-            )
-
-        return LLMResponse(
-            content=choice.get("message", {}).get("content", ""),
-            model=data.get("model"),
-            usage=usage,
-            finish_reason=choice.get("finish_reason")
-        )
-
-    async def validate_config(self) -> bool:
-        # Ollama runs locally and needs no API key
-        if not self.config.model:
-            raise Exception("No Ollama model specified")
-        return True
openai_adapter.py (deleted)

@@ -1,93 +0,0 @@
-"""
-OpenAI adapter (supports the GPT series and OpenAI-compatible APIs)
-"""
-
-from typing import Dict, Any
-from ..base_adapter import BaseLLMAdapter
-from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider
-
-
-class OpenAIAdapter(BaseLLMAdapter):
-    """OpenAI adapter"""
-
-    @property
-    def base_url(self) -> str:
-        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.OPENAI, "https://api.openai.com/v1")
-
-    async def complete(self, request: LLMRequest) -> LLMResponse:
-        try:
-            await self.validate_config()
-            return await self.retry(lambda: self._send_request(request))
-        except Exception as error:
-            self.handle_error(error, "OpenAI API call failed")
-
-    async def _send_request(self, request: LLMRequest) -> LLMResponse:
-        # Build the request headers
-        headers = {
-            "Authorization": f"Bearer {self.config.api_key}",
-        }
-
-        # Detect reasoning models (o1/o3 series)
-        model_name = self.config.model.lower()
-        is_reasoning_model = "o1" in model_name or "o3" in model_name
-
-        # Build the request body
-        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
-        request_body: Dict[str, Any] = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
-            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
-            "frequency_penalty": self.config.frequency_penalty,
-            "presence_penalty": self.config.presence_penalty,
-        }
-
-        # Reasoning models take max_completion_tokens; all others take max_tokens
-        max_tokens = request.max_tokens if request.max_tokens is not None else self.config.max_tokens
-        if is_reasoning_model:
-            request_body["max_completion_tokens"] = max_tokens
-        else:
-            request_body["max_tokens"] = max_tokens
-
-        url = f"{self.base_url.rstrip('/')}/chat/completions"
-
-        response = await self.client.post(
-            url,
-            headers=self.build_headers(headers),
-            json=request_body
-        )
-
-        if response.status_code != 200:
-            error_data = response.json() if response.text else {}
-            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
-            raise Exception(f"{error_msg}")
-
-        data = response.json()
-        choice = data.get("choices", [{}])[0]
-
-        if not choice:
-            raise Exception("Unexpected API response: missing choices field")
-
-        usage = None
-        if "usage" in data:
-            usage = LLMUsage(
-                prompt_tokens=data["usage"].get("prompt_tokens", 0),
-                completion_tokens=data["usage"].get("completion_tokens", 0),
-                total_tokens=data["usage"].get("total_tokens", 0)
-            )
-
-        return LLMResponse(
-            content=choice.get("message", {}).get("content", ""),
-            model=data.get("model"),
-            usage=usage,
-            finish_reason=choice.get("finish_reason")
-        )
-
-    async def validate_config(self) -> bool:
-        await super().validate_config()
-        if not self.config.model:
-            raise Exception("No OpenAI model specified")
-        return True
qwen_adapter.py (deleted)

@@ -1,80 +0,0 @@
-"""
-Alibaba Cloud Qwen (Tongyi Qianwen) adapter - OpenAI-compatible format
-"""
-
-from typing import Dict, Any
-from ..base_adapter import BaseLLMAdapter
-from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider
-
-
-class QwenAdapter(BaseLLMAdapter):
-    """Qwen adapter"""
-
-    @property
-    def base_url(self) -> str:
-        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.QWEN, "https://dashscope.aliyuncs.com/compatible-mode/v1")
-
-    async def complete(self, request: LLMRequest) -> LLMResponse:
-        try:
-            await self.validate_config()
-            return await self.retry(lambda: self._send_request(request))
-        except Exception as error:
-            self.handle_error(error, "Qwen API call failed")
-
-    async def _send_request(self, request: LLMRequest) -> LLMResponse:
-        # Qwen is OpenAI-compatible
-        headers = {
-            "Authorization": f"Bearer {self.config.api_key}",
-        }
-
-        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
-        request_body: Dict[str, Any] = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
-            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
-            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
-        }
-
-        url = f"{self.base_url.rstrip('/')}/chat/completions"
-
-        response = await self.client.post(
-            url,
-            headers=self.build_headers(headers),
-            json=request_body
-        )
-
-        if response.status_code != 200:
-            error_data = response.json() if response.text else {}
-            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
-            raise Exception(f"{error_msg}")
-
-        data = response.json()
-        choice = data.get("choices", [{}])[0]
-
-        if not choice:
-            raise Exception("Unexpected API response: missing choices field")
-
-        usage = None
-        if "usage" in data:
-            usage = LLMUsage(
-                prompt_tokens=data["usage"].get("prompt_tokens", 0),
-                completion_tokens=data["usage"].get("completion_tokens", 0),
-                total_tokens=data["usage"].get("total_tokens", 0)
-            )
-
-        return LLMResponse(
-            content=choice.get("message", {}).get("content", ""),
-            model=data.get("model"),
-            usage=usage,
-            finish_reason=choice.get("finish_reason")
-        )
-
-    async def validate_config(self) -> bool:
-        await super().validate_config()
-        if not self.config.model:
-            raise Exception("No Qwen model specified")
-        return True
zhipu_adapter.py (deleted)

@@ -1,80 +0,0 @@
-"""
-Zhipu AI adapter (GLM series) - OpenAI-compatible format
-"""
-
-from typing import Dict, Any
-from ..base_adapter import BaseLLMAdapter
-from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider
-
-
-class ZhipuAdapter(BaseLLMAdapter):
-    """Zhipu AI adapter"""
-
-    @property
-    def base_url(self) -> str:
-        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.ZHIPU, "https://open.bigmodel.cn/api/paas/v4")
-
-    async def complete(self, request: LLMRequest) -> LLMResponse:
-        try:
-            await self.validate_config()
-            return await self.retry(lambda: self._send_request(request))
-        except Exception as error:
-            self.handle_error(error, "Zhipu AI API call failed")
-
-    async def _send_request(self, request: LLMRequest) -> LLMResponse:
-        # Zhipu AI is OpenAI-compatible
-        headers = {
-            "Authorization": f"Bearer {self.config.api_key}",
-        }
-
-        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
-        request_body: Dict[str, Any] = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
-            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
-            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
-        }
-
-        url = f"{self.base_url.rstrip('/')}/chat/completions"
-
-        response = await self.client.post(
-            url,
-            headers=self.build_headers(headers),
-            json=request_body
-        )
-
-        if response.status_code != 200:
-            error_data = response.json() if response.text else {}
-            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
-            raise Exception(f"{error_msg}")
-
-        data = response.json()
-        choice = data.get("choices", [{}])[0]
-
-        if not choice:
-            raise Exception("Unexpected API response: missing choices field")
-
-        usage = None
-        if "usage" in data:
-            usage = LLMUsage(
-                prompt_tokens=data["usage"].get("prompt_tokens", 0),
-                completion_tokens=data["usage"].get("completion_tokens", 0),
-                total_tokens=data["usage"].get("total_tokens", 0)
-            )
-
-        return LLMResponse(
-            content=choice.get("message", {}).get("content", ""),
-            model=data.get("model"),
-            usage=usage,
-            finish_reason=choice.get("finish_reason")
-        )
-
-    async def validate_config(self) -> bool:
-        await super().validate_config()
-        if not self.config.model:
-            raise Exception("No Zhipu AI model specified")
-        return True
factory.py

@@ -1,25 +1,29 @@
 """
 LLM factory - unified creation and management of LLM adapters
+
+LiteLLM serves as the primary adapter and covers most LLM providers.
+Providers with special API formats (Baidu, MiniMax, Doubao) use native adapters.
 """
 
-from typing import Dict, List, Optional
+from typing import Dict, List
 from .types import LLMConfig, LLMProvider, DEFAULT_MODELS
 from .base_adapter import BaseLLMAdapter
 from .adapters import (
-    OpenAIAdapter,
-    GeminiAdapter,
-    ClaudeAdapter,
-    DeepSeekAdapter,
-    QwenAdapter,
-    ZhipuAdapter,
-    MoonshotAdapter,
+    LiteLLMAdapter,
     BaiduAdapter,
     MinimaxAdapter,
     DoubaoAdapter,
-    OllamaAdapter,
 )
 
 
+# Providers that must use a native adapter (special API formats)
+NATIVE_ONLY_PROVIDERS = {
+    LLMProvider.BAIDU,
+    LLMProvider.MINIMAX,
+    LLMProvider.DOUBAO,
+}
+
+
 class LLMFactory:
     """LLM factory class"""
@@ -49,23 +53,29 @@ class LLMFactory:
         if not config.model:
             config.model = DEFAULT_MODELS.get(config.provider, "gpt-4o-mini")
 
-        adapter_map = {
-            LLMProvider.OPENAI: OpenAIAdapter,
-            LLMProvider.GEMINI: GeminiAdapter,
-            LLMProvider.CLAUDE: ClaudeAdapter,
-            LLMProvider.DEEPSEEK: DeepSeekAdapter,
-            LLMProvider.QWEN: QwenAdapter,
-            LLMProvider.ZHIPU: ZhipuAdapter,
-            LLMProvider.MOONSHOT: MoonshotAdapter,
+        # Providers that must use a native adapter
+        if config.provider in NATIVE_ONLY_PROVIDERS:
+            return cls._create_native_adapter(config)
+
+        # All other providers go through LiteLLM
+        if LiteLLMAdapter.supports_provider(config.provider):
+            return LiteLLMAdapter(config)
+
+        # Unsupported provider
+        raise ValueError(f"Unsupported LLM provider: {config.provider}")
+
+    @classmethod
+    def _create_native_adapter(cls, config: LLMConfig) -> BaseLLMAdapter:
+        """Create a native adapter (only for providers with special API formats)"""
+        native_adapter_map = {
             LLMProvider.BAIDU: BaiduAdapter,
             LLMProvider.MINIMAX: MinimaxAdapter,
             LLMProvider.DOUBAO: DoubaoAdapter,
-            LLMProvider.OLLAMA: OllamaAdapter,
         }
 
-        adapter_class = adapter_map.get(config.provider)
+        adapter_class = native_adapter_map.get(config.provider)
         if not adapter_class:
-            raise ValueError(f"Unsupported LLM provider: {config.provider}")
+            raise ValueError(f"Provider has no native adapter: {config.provider}")
 
         return adapter_class(config)
@@ -162,4 +172,3 @@ class LLMFactory:
             ],
         }
         return models.get(provider, [])
-
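Taken together with the adapter changes above, callers never see the LiteLLM/native split. A minimal usage sketch built only from the types this commit defines; the API key is a placeholder:

import asyncio
from app.services.llm.factory import LLMFactory
from app.services.llm.types import LLMConfig, LLMProvider, LLMRequest, LLMMessage

async def main() -> None:
    # DeepSeek routes through LiteLLMAdapter; LLMProvider.BAIDU would
    # transparently resolve to the native BaiduAdapter instead.
    config = LLMConfig(provider=LLMProvider.DEEPSEEK, api_key="sk-...", model="deepseek-chat")
    adapter = LLMFactory.create_adapter(config)
    response = await adapter.complete(
        LLMRequest(messages=[LLMMessage(role="user", content="ping")], max_tokens=20)
    )
    print(response.content)

asyncio.run(main())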
env.example

@@ -16,6 +16,9 @@ ACCESS_TOKEN_EXPIRE_MINUTES=11520
 
 # ------------ LLM configuration ------------
 # Supported providers: openai, gemini, claude, qwen, deepseek, zhipu, moonshot, baidu, minimax, doubao, ollama
+#
+# Via the LiteLLM unified adapter: openai, gemini, claude, qwen, deepseek, zhipu, moonshot, ollama
+# Via native adapters (special API formats): baidu, minimax, doubao
 LLM_PROVIDER=openai
 LLM_API_KEY=sk-your-api-key
 LLM_MODEL=
pyproject.toml

@@ -18,4 +18,5 @@ dependencies = [
     "email-validator",
     "greenlet",
     "bcrypt<5.0.0",
+    "litellm>=1.0.0",
 ]
File diff suppressed because it is too large.
@@ -273,6 +273,34 @@ export const api = {
     await apiClient.delete('/config/me');
   },
 
+  async testLLMConnection(params: {
+    provider: string;
+    apiKey: string;
+    model?: string;
+    baseUrl?: string;
+  }): Promise<{
+    success: boolean;
+    message: string;
+    model?: string;
+    response?: string;
+  }> {
+    const res = await apiClient.post('/config/test-llm', params);
+    return res.data;
+  },
+
+  async getLLMProviders(): Promise<{
+    providers: Array<{
+      id: string;
+      name: string;
+      defaultModel: string;
+      models: string[];
+      defaultBaseUrl: string;
+    }>;
+  }> {
+    const res = await apiClient.get('/config/llm-providers');
+    return res.data;
+  },
+
   // ==================== Database management methods ====================
 
   async exportDatabase(): Promise<{