refactor(llm): consolidate LLM adapters with LiteLLM unified layer
- Replace individual adapter implementations (OpenAI, Claude, Gemini, DeepSeek, Qwen, Zhipu, Moonshot, Ollama) with a unified LiteLLM adapter
- Keep native adapters for providers with special API formats (Baidu, MiniMax, Doubao)
- Update LLM factory to route requests through LiteLLM for supported providers
- Add test-llm endpoint to validate LLM connections with configurable timeout and token limits
- Add get-llm-providers endpoint to retrieve supported providers and their configurations
- Update config.py to ignore extra environment variables (VITE_* frontend variables)
- Refactor Baidu adapter to use the new complete() method signature and improve error handling
- Update pyproject.toml dependencies to include the litellm package
- Update env.example with new configuration options
- Simplify adapter initialization and reduce code duplication across multiple provider implementations
This commit is contained in:
parent 1fc0ecd14a
commit 22c528acf1
@@ -216,3 +216,109 @@ async def delete_my_config(
    return {"message": "配置已删除"}


class LLMTestRequest(BaseModel):
    """LLM测试请求"""
    provider: str
    apiKey: str
    model: Optional[str] = None
    baseUrl: Optional[str] = None


class LLMTestResponse(BaseModel):
    """LLM测试响应"""
    success: bool
    message: str
    model: Optional[str] = None
    response: Optional[str] = None


@router.post("/test-llm", response_model=LLMTestResponse)
async def test_llm_connection(
    request: LLMTestRequest,
    current_user: User = Depends(deps.get_current_user),
) -> Any:
    """测试LLM连接是否正常"""
    from app.services.llm.factory import LLMFactory
    from app.services.llm.types import LLMConfig, LLMProvider, LLMRequest, LLMMessage, DEFAULT_MODELS

    try:
        # 解析provider
        provider_map = {
            'gemini': LLMProvider.GEMINI,
            'openai': LLMProvider.OPENAI,
            'claude': LLMProvider.CLAUDE,
            'qwen': LLMProvider.QWEN,
            'deepseek': LLMProvider.DEEPSEEK,
            'zhipu': LLMProvider.ZHIPU,
            'moonshot': LLMProvider.MOONSHOT,
            'baidu': LLMProvider.BAIDU,
            'minimax': LLMProvider.MINIMAX,
            'doubao': LLMProvider.DOUBAO,
            'ollama': LLMProvider.OLLAMA,
        }

        provider = provider_map.get(request.provider.lower())
        if not provider:
            return LLMTestResponse(
                success=False,
                message=f"不支持的LLM提供商: {request.provider}"
            )

        # 获取默认模型
        model = request.model or DEFAULT_MODELS.get(provider)

        # 创建配置
        config = LLMConfig(
            provider=provider,
            api_key=request.apiKey,
            model=model,
            base_url=request.baseUrl,
            timeout=30,  # 测试使用较短的超时时间
            max_tokens=50,  # 测试使用较少的token
        )

        # 创建适配器并测试
        adapter = LLMFactory.create_adapter(config)

        test_request = LLMRequest(
            messages=[
                LLMMessage(role="user", content="Say 'Hello' in one word.")
            ],
            max_tokens=50,
        )

        response = await adapter.complete(test_request)

        return LLMTestResponse(
            success=True,
            message="LLM连接测试成功",
            model=model,
            response=response.content[:100] if response.content else None
        )

    except Exception as e:
        return LLMTestResponse(
            success=False,
            message=f"LLM连接测试失败: {str(e)}"
        )


@router.get("/llm-providers")
async def get_llm_providers() -> Any:
    """获取支持的LLM提供商列表"""
    from app.services.llm.factory import LLMFactory
    from app.services.llm.types import LLMProvider, DEFAULT_BASE_URLS

    providers = []
    for provider in LLMFactory.get_supported_providers():
        providers.append({
            "id": provider.value,
            "name": provider.value.upper(),
            "defaultModel": LLMFactory.get_default_model(provider),
            "models": LLMFactory.get_available_models(provider),
            "defaultBaseUrl": DEFAULT_BASE_URLS.get(provider, ""),
        })

    return {"providers": providers}
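For reference, a minimal sketch of exercising the two new endpoints from a client. The host, API prefix, and bearer-token auth are assumptions (they are not shown in this diff); the request and response shapes follow LLMTestRequest and LLMTestResponse above.

# Hypothetical smoke test for the new endpoints; base URL, prefix and token are assumptions.
import asyncio
import httpx

BASE_URL = "http://localhost:8000/api/v1/config"  # assumed mount point of this router
TOKEN = "your-jwt-token"                           # assumed bearer-token auth


async def main() -> None:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    async with httpx.AsyncClient(timeout=60) as client:
        # List supported providers with their default models and base URLs
        providers = await client.get(f"{BASE_URL}/llm-providers", headers=headers)
        print(providers.json())

        # Validate a provider/key pair with a short test completion
        result = await client.post(
            f"{BASE_URL}/test-llm",
            headers=headers,
            json={"provider": "openai", "apiKey": "sk-...", "model": "gpt-4o-mini"},
        )
        print(result.json())  # e.g. {"success": true, "message": "LLM连接测试成功", ...}


asyncio.run(main())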
@@ -77,6 +77,7 @@ class Settings(BaseSettings):
    class Config:
        case_sensitive = True
        env_file = ".env"
        extra = "ignore"  # 忽略额外的环境变量(如 VITE_* 前端变量)


settings = Settings()
@@ -1,30 +1,22 @@
"""
LLM适配器模块

适配器分为两类:
1. LiteLLM 统一适配器 - 支持 OpenAI, Claude, Gemini, DeepSeek, Qwen, Zhipu, Moonshot, Ollama
2. 原生适配器 - 用于 API 格式特殊的提供商: Baidu, MiniMax, Doubao
"""

from .openai_adapter import OpenAIAdapter
from .gemini_adapter import GeminiAdapter
from .claude_adapter import ClaudeAdapter
from .deepseek_adapter import DeepSeekAdapter
from .qwen_adapter import QwenAdapter
from .zhipu_adapter import ZhipuAdapter
from .moonshot_adapter import MoonshotAdapter
# LiteLLM 统一适配器
from .litellm_adapter import LiteLLMAdapter

# 原生适配器 (用于 API 格式特殊的提供商)
from .baidu_adapter import BaiduAdapter
from .minimax_adapter import MinimaxAdapter
from .doubao_adapter import DoubaoAdapter
from .ollama_adapter import OllamaAdapter

__all__ = [
    'OpenAIAdapter',
    'GeminiAdapter',
    'ClaudeAdapter',
    'DeepSeekAdapter',
    'QwenAdapter',
    'ZhipuAdapter',
    'MoonshotAdapter',
    'BaiduAdapter',
    'MinimaxAdapter',
    'DoubaoAdapter',
    'OllamaAdapter',
    "LiteLLMAdapter",
    "BaiduAdapter",
    "MinimaxAdapter",
    "DoubaoAdapter",
]
@@ -3,10 +3,9 @@
"""

import httpx
import json
from typing import Optional
from ..base_adapter import BaseLLMAdapter
from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage


class BaiduAdapter(BaseLLMAdapter):

@@ -70,8 +69,16 @@ class BaiduAdapter(BaseLLMAdapter):
        return self._access_token

    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
    async def complete(self, request: LLMRequest) -> LLMResponse:
        """执行实际的API调用"""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "百度文心一言 API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """发送请求"""
        access_token = await self._get_access_token()

        # 获取模型对应的API端点

@@ -84,55 +91,56 @@ class BaiduAdapter(BaseLLMAdapter):
        payload = {
            "messages": messages,
            "temperature": request.temperature or self.config.temperature,
            "top_p": request.top_p or self.config.top_p,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        if request.max_tokens or self.config.max_tokens:
            payload["max_output_tokens"] = request.max_tokens or self.config.max_tokens

        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
            response = await client.post(
        response = await self.client.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"}
            headers=self.build_headers(),
            json=payload
        )

        if response.status_code != 200:
            raise LLMError(
                f"百度API错误: {response.text}",
                provider="baidu",
                status_code=response.status_code
            )
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error_msg", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        if "error_code" in data:
            raise LLMError(
                f"百度API错误: {data.get('error_msg', '未知错误')}",
                provider="baidu",
                status_code=data.get("error_code")
            raise Exception(f"百度API错误: {data.get('error_msg', '未知错误')}")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=data.get("result", ""),
            model=model,
            usage=data.get("usage"),
            usage=usage,
            finish_reason=data.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        """验证配置是否有效"""
        try:
            await self._get_access_token()
        if not self.config.api_key:
            raise LLMError(
                "API Key未配置",
                provider=LLMProvider.BAIDU
            )
        if ":" not in self.config.api_key:
            raise LLMError(
                "百度API需要同时提供API Key和Secret Key,格式:api_key:secret_key",
                provider=LLMProvider.BAIDU
            )
        return True
        except Exception:
            return False

    def get_provider(self) -> str:
        return "baidu"

    def get_model(self) -> str:
        return self.config.model or "ERNIE-3.5-8K"
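Note the stricter key check above: the Baidu adapter now expects the API Key and Secret Key packed into one string. A minimal sketch of a matching configuration, using the LLMConfig fields that appear elsewhere in this diff (the key values are placeholders):

# Hypothetical Baidu configuration; "api_key:secret_key" matches the validate_config check above.
from app.services.llm.factory import LLMFactory
from app.services.llm.types import LLMConfig, LLMProvider

config = LLMConfig(
    provider=LLMProvider.BAIDU,
    api_key="your-api-key:your-secret-key",  # both halves are required, joined by ":"
    model="ERNIE-3.5-8K",
)
adapter = LLMFactory.create_adapter(config)  # routed to the native BaiduAdapter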
@@ -1,94 +0,0 @@
"""
Anthropic Claude适配器
"""

from typing import Dict, Any
from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class ClaudeAdapter(BaseLLMAdapter):
    """Claude适配器"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.CLAUDE, "https://api.anthropic.com/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "Claude API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # Claude API需要将system消息分离
        system_message = None
        messages = []

        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                messages.append({
                    "role": msg.role,
                    "content": msg.content
                })

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens or 4096,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        if system_message:
            request_body["system"] = system_message

        # 构建请求头
        headers = {
            "x-api-key": self.config.api_key,
            "anthropic-version": "2023-06-01",
        }

        url = f"{self.base_url.rstrip('/')}/messages"

        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        if not data.get("content") or not data["content"][0]:
            raise Exception("API响应格式异常: 缺少content字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("input_tokens", 0),
                completion_tokens=data["usage"].get("output_tokens", 0),
                total_tokens=data["usage"].get("input_tokens", 0) + data["usage"].get("output_tokens", 0)
            )

        return LLMResponse(
            content=data["content"][0].get("text", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=data.get("stop_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model.startswith("claude-"):
            raise Exception(f"无效的Claude模型: {self.config.model}")
        return True
@@ -1,82 +0,0 @@
"""
DeepSeek适配器 - 兼容OpenAI格式
"""

from typing import Dict, Any
from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class DeepSeekAdapter(BaseLLMAdapter):
    """DeepSeek适配器"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.DEEPSEEK, "https://api.deepseek.com")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "DeepSeek API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # DeepSeek API兼容OpenAI格式
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
            "frequency_penalty": self.config.frequency_penalty,
            "presence_penalty": self.config.presence_penalty,
        }

        url = f"{self.base_url.rstrip('/')}/v1/chat/completions"

        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()
        choice = data.get("choices", [{}])[0]

        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定DeepSeek模型")
        return True
@@ -2,9 +2,8 @@
字节跳动豆包适配器
"""

import httpx
from ..base_adapter import BaseLLMAdapter
from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage


class DoubaoAdapter(BaseLLMAdapter):

@@ -17,8 +16,16 @@ class DoubaoAdapter(BaseLLMAdapter):
        super().__init__(config)
        self._base_url = config.base_url or "https://ark.cn-beijing.volces.com/api/v3"

    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
    async def complete(self, request: LLMRequest) -> LLMResponse:
        """执行实际的API调用"""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "豆包 API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """发送请求"""
        url = f"{self._base_url}/chat/completions"

        messages = [{"role": m.role, "content": m.content} for m in request.messages]

@@ -26,63 +33,52 @@ class DoubaoAdapter(BaseLLMAdapter):
        payload = {
            "model": self.config.model or "doubao-pro-32k",
            "messages": messages,
            "temperature": request.temperature or self.config.temperature,
            "top_p": request.top_p or self.config.top_p,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        if request.max_tokens or self.config.max_tokens:
            payload["max_tokens"] = request.max_tokens or self.config.max_tokens

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.api_key}",
        }

        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=payload
        )

        if response.status_code != 200:
            raise LLMError(
                f"豆包API错误: {response.text}",
                provider="doubao",
                status_code=response.status_code
            )
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()
        choice = data.get("choices", [{}])[0]

        if "error" in data:
            raise LLMError(
                f"豆包API错误: {data['error'].get('message', '未知错误')}",
                provider="doubao"
        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        choices = data.get("choices", [])
        if not choices:
            raise LLMError("豆包API返回空响应", provider="doubao")

        return LLMResponse(
            content=choices[0].get("message", {}).get("content", ""),
            model=data.get("model", self.config.model or "doubao-pro-32k"),
            usage=data.get("usage"),
            finish_reason=choices[0].get("finish_reason")
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        """验证配置是否有效"""
        try:
            test_request = LLMRequest(
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=10
            )
            await self._do_complete(test_request)
        await super().validate_config()
        if not self.config.model:
            raise LLMError("未指定豆包模型", provider=LLMProvider.DOUBAO)
        return True
        except Exception:
            return False

    def get_provider(self) -> str:
        return "doubao"

    def get_model(self) -> str:
        return self.config.model or "doubao-pro-32k"
@@ -1,114 +0,0 @@
"""
Google Gemini适配器 - 支持官方API和中转站
"""

from typing import Dict, Any, List
from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class GeminiAdapter(BaseLLMAdapter):
    """Gemini适配器"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.GEMINI, "https://generativelanguage.googleapis.com/v1beta")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._generate_content(request))
        except Exception as error:
            self.handle_error(error, "Gemini API调用失败")

    async def _generate_content(self, request: LLMRequest) -> LLMResponse:
        # 转换消息格式为 Gemini 格式
        contents: List[Dict[str, Any]] = []
        system_content = ""

        for msg in request.messages:
            if msg.role == "system":
                system_content = msg.content
            else:
                role = "model" if msg.role == "assistant" else "user"
                contents.append({
                    "role": role,
                    "parts": [{"text": msg.content}]
                })

        # 将系统消息合并到第一条用户消息
        if system_content and contents:
            contents[0]["parts"][0]["text"] = f"{system_content}\n\n{contents[0]['parts'][0]['text']}"

        # 构建请求体
        request_body = {
            "contents": contents,
            "generationConfig": {
                "temperature": request.temperature if request.temperature is not None else self.config.temperature,
                "maxOutputTokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
                "topP": request.top_p if request.top_p is not None else self.config.top_p,
            }
        }

        # API Key 在 URL 参数中
        url = f"{self.base_url}/models/{self.config.model}:generateContent?key={self.config.api_key}"

        response = await self.client.post(
            url,
            headers=self.build_headers(),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        # 解析 Gemini 响应格式
        candidates = data.get("candidates", [])
        if not candidates:
            # 检查是否有错误信息
            if "error" in data:
                error_msg = data["error"].get("message", "未知错误")
                raise Exception(f"Gemini API错误: {error_msg}")
            raise Exception("API响应格式异常: 缺少candidates字段")

        candidate = candidates[0]
        if not candidate or "content" not in candidate:
            raise Exception("API响应格式异常: 缺少content字段")

        text_parts = candidate.get("content", {}).get("parts", [])
        if not text_parts:
            raise Exception("API响应格式异常: content.parts为空")

        text = "".join(part.get("text", "") for part in text_parts)

        # 检查响应内容是否为空
        if not text or not text.strip():
            finish_reason = candidate.get("finishReason", "unknown")
            raise Exception(f"Gemini返回空响应 - Finish Reason: {finish_reason}")

        usage = None
        if "usageMetadata" in data:
            usage_data = data["usageMetadata"]
            usage = LLMUsage(
                prompt_tokens=usage_data.get("promptTokenCount", 0),
                completion_tokens=usage_data.get("candidatesTokenCount", 0),
                total_tokens=usage_data.get("totalTokenCount", 0)
            )

        return LLMResponse(
            content=text,
            model=self.config.model,
            usage=usage,
            finish_reason=candidate.get("finishReason", "stop")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model.startswith("gemini-"):
            raise Exception(f"无效的Gemini模型: {self.config.model}")
        return True
@@ -0,0 +1,182 @@
"""
LiteLLM 统一适配器
支持通过 LiteLLM 调用多个 LLM 提供商,使用统一的 OpenAI 兼容格式
"""

from typing import Dict, Any, Optional
from ..base_adapter import BaseLLMAdapter
from ..types import (
    LLMConfig,
    LLMRequest,
    LLMResponse,
    LLMUsage,
    LLMProvider,
    LLMError,
    DEFAULT_BASE_URLS,
)


class LiteLLMAdapter(BaseLLMAdapter):
    """
    LiteLLM 统一适配器

    支持的提供商:
    - OpenAI (openai/gpt-4o-mini)
    - Claude (anthropic/claude-3-5-sonnet-20241022)
    - Gemini (gemini/gemini-1.5-flash)
    - DeepSeek (deepseek/deepseek-chat)
    - Qwen (qwen/qwen-turbo) - 通过 OpenAI 兼容模式
    - Zhipu (zhipu/glm-4-flash) - 通过 OpenAI 兼容模式
    - Moonshot (moonshot/moonshot-v1-8k) - 通过 OpenAI 兼容模式
    - Ollama (ollama/llama3)
    """

    # LiteLLM 模型前缀映射
    PROVIDER_PREFIX_MAP = {
        LLMProvider.OPENAI: "openai",
        LLMProvider.CLAUDE: "anthropic",
        LLMProvider.GEMINI: "gemini",
        LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "openai",  # 使用 OpenAI 兼容模式
        LLMProvider.ZHIPU: "openai",  # 使用 OpenAI 兼容模式
        LLMProvider.MOONSHOT: "openai",  # 使用 OpenAI 兼容模式
        LLMProvider.OLLAMA: "ollama",
    }

    # 需要自定义 base_url 的提供商
    CUSTOM_BASE_URL_PROVIDERS = {
        LLMProvider.QWEN,
        LLMProvider.ZHIPU,
        LLMProvider.MOONSHOT,
        LLMProvider.DEEPSEEK,
    }

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._litellm_model = self._get_litellm_model()
        self._api_base = self._get_api_base()

    def _get_litellm_model(self) -> str:
        """获取 LiteLLM 格式的模型名称"""
        provider = self.config.provider
        model = self.config.model

        # 对于使用 OpenAI 兼容模式的提供商,直接使用模型名
        if provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return model

        # 对于原生支持的提供商,添加前缀
        prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")

        # 检查模型名是否已经包含前缀
        if "/" in model:
            return model

        return f"{prefix}/{model}"

    def _get_api_base(self) -> Optional[str]:
        """获取 API 基础 URL"""
        # 优先使用用户配置的 base_url
        if self.config.base_url:
            return self.config.base_url

        # 对于需要自定义 base_url 的提供商,使用默认值
        if self.config.provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return DEFAULT_BASE_URLS.get(self.config.provider)

        # Ollama 使用本地地址
        if self.config.provider == LLMProvider.OLLAMA:
            return DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434")

        return None

    async def complete(self, request: LLMRequest) -> LLMResponse:
        """使用 LiteLLM 发送请求"""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, f"LiteLLM ({self.config.provider.value}) API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """发送请求到 LiteLLM"""
        import litellm

        # 构建消息
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # 构建请求参数
        kwargs: Dict[str, Any] = {
            "model": self._litellm_model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        # 设置 API Key
        if self.config.api_key and self.config.api_key != "ollama":
            kwargs["api_key"] = self.config.api_key

        # 设置 API Base URL
        if self._api_base:
            kwargs["api_base"] = self._api_base

        # 设置超时
        kwargs["timeout"] = self.config.timeout

        # 对于 OpenAI 提供商,添加额外参数
        if self.config.provider == LLMProvider.OPENAI:
            kwargs["frequency_penalty"] = self.config.frequency_penalty
            kwargs["presence_penalty"] = self.config.presence_penalty

        # 调用 LiteLLM
        response = await litellm.acompletion(**kwargs)

        # 解析响应
        choice = response.choices[0] if response.choices else None
        if not choice:
            raise LLMError("API响应格式异常: 缺少choices字段", self.config.provider)

        usage = None
        if hasattr(response, "usage") and response.usage:
            usage = LLMUsage(
                prompt_tokens=response.usage.prompt_tokens or 0,
                completion_tokens=response.usage.completion_tokens or 0,
                total_tokens=response.usage.total_tokens or 0,
            )

        return LLMResponse(
            content=choice.message.content or "",
            model=response.model,
            usage=usage,
            finish_reason=choice.finish_reason,
        )

    async def validate_config(self) -> bool:
        """验证配置"""
        # Ollama 不需要 API Key
        if self.config.provider == LLMProvider.OLLAMA:
            if not self.config.model:
                raise LLMError("未指定 Ollama 模型", LLMProvider.OLLAMA)
            return True

        # 其他提供商需要 API Key
        if not self.config.api_key:
            raise LLMError(
                f"API Key未配置 ({self.config.provider.value})",
                self.config.provider,
            )

        if not self.config.model:
            raise LLMError(
                f"未指定模型 ({self.config.provider.value})",
                self.config.provider,
            )

        return True

    @classmethod
    def supports_provider(cls, provider: LLMProvider) -> bool:
        """检查是否支持指定的提供商"""
        return provider in cls.PROVIDER_PREFIX_MAP
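For context, a rough standalone illustration of the kind of litellm.acompletion call that _send_request assembles for a natively supported provider; the API key and model are placeholders, and this is a sketch rather than part of the commit.

# Minimal sketch of the underlying LiteLLM call; api_key and model are placeholders.
import asyncio
import litellm


async def demo() -> None:
    response = await litellm.acompletion(
        model="gemini/gemini-1.5-flash",  # provider prefix + model, as built by _get_litellm_model
        messages=[{"role": "user", "content": "Say 'Hello' in one word."}],
        api_key="your-gemini-key",
        max_tokens=50,
        timeout=30,
    )
    print(response.choices[0].message.content)


asyncio.run(demo())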
@@ -2,9 +2,8 @@
MiniMax适配器
"""

import httpx
from ..base_adapter import BaseLLMAdapter
from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError
from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage


class MinimaxAdapter(BaseLLMAdapter):

@@ -14,8 +13,16 @@ class MinimaxAdapter(BaseLLMAdapter):
        super().__init__(config)
        self._base_url = config.base_url or "https://api.minimax.chat/v1"

    async def _do_complete(self, request: LLMRequest) -> LLMResponse:
    async def complete(self, request: LLMRequest) -> LLMResponse:
        """执行实际的API调用"""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "MiniMax API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """发送请求"""
        url = f"{self._base_url}/text/chatcompletion_v2"

        messages = [{"role": m.role, "content": m.content} for m in request.messages]

@@ -23,63 +30,58 @@ class MinimaxAdapter(BaseLLMAdapter):
        payload = {
            "model": self.config.model or "abab6.5-chat",
            "messages": messages,
            "temperature": request.temperature or self.config.temperature,
            "top_p": request.top_p or self.config.top_p,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        if request.max_tokens or self.config.max_tokens:
            payload["max_tokens"] = request.max_tokens or self.config.max_tokens

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.config.api_key}",
        }

        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
            response = await client.post(url, json=payload, headers=headers)
        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=payload
        )

        if response.status_code != 200:
            raise LLMError(
                f"MiniMax API错误: {response.text}",
                provider="minimax",
                status_code=response.status_code
            )
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("base_resp", {}).get("status_msg", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()

        # MiniMax 特殊的错误处理
        if data.get("base_resp", {}).get("status_code") != 0:
            raise LLMError(
                f"MiniMax API错误: {data.get('base_resp', {}).get('status_msg', '未知错误')}",
                provider="minimax"
            error_msg = data.get("base_resp", {}).get("status_msg", "未知错误")
            raise Exception(f"MiniMax API错误: {error_msg}")

        choice = data.get("choices", [{}])[0]

        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        choices = data.get("choices", [])
        if not choices:
            raise LLMError("MiniMax API返回空响应", provider="minimax")

        return LLMResponse(
            content=choices[0].get("message", {}).get("content", ""),
            model=self.config.model or "abab6.5-chat",
            usage=data.get("usage"),
            finish_reason=choices[0].get("finish_reason")
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        """验证配置是否有效"""
        try:
            test_request = LLMRequest(
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=10
            )
            await self._do_complete(test_request)
        await super().validate_config()
        if not self.config.model:
            raise LLMError("未指定MiniMax模型", provider=LLMProvider.MINIMAX)
        return True
        except Exception:
            return False

    def get_provider(self) -> str:
        return "minimax"

    def get_model(self) -> str:
        return self.config.model or "abab6.5-chat"
@@ -1,80 +0,0 @@
"""
月之暗面 Kimi适配器 - 兼容OpenAI格式
"""

from typing import Dict, Any
from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class MoonshotAdapter(BaseLLMAdapter):
    """月之暗面Kimi适配器"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.MOONSHOT, "https://api.moonshot.cn/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "Moonshot API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # Moonshot API兼容OpenAI格式
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        url = f"{self.base_url.rstrip('/')}/chat/completions"

        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()
        choice = data.get("choices", [{}])[0]

        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定Moonshot模型")
        return True
@@ -1,83 +0,0 @@
"""
Ollama本地大模型适配器 - 兼容OpenAI格式
"""

from typing import Dict, Any
from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class OllamaAdapter(BaseLLMAdapter):
    """Ollama本地模型适配器"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            # Ollama本地运行,跳过API Key验证
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "Ollama API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # Ollama兼容OpenAI格式
        headers = {}
        if self.config.api_key:
            headers["Authorization"] = f"Bearer {self.config.api_key}"

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        # Ollama的max_tokens参数名可能不同
        if request.max_tokens or self.config.max_tokens:
            request_body["num_predict"] = request.max_tokens or self.config.max_tokens

        url = f"{self.base_url.rstrip('/')}/chat/completions"

        response = await self.client.post(
            url,
            headers=self.build_headers(headers) if headers else self.build_headers(),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()
        choice = data.get("choices", [{}])[0]

        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        # Ollama本地运行,不需要API Key
        if not self.config.model:
            raise Exception("未指定Ollama模型")
        return True
@@ -1,93 +0,0 @@
"""
OpenAI适配器 (支持GPT系列和OpenAI兼容API)
"""

from typing import Dict, Any
from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class OpenAIAdapter(BaseLLMAdapter):
    """OpenAI适配器"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.OPENAI, "https://api.openai.com/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "OpenAI API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # 构建请求头
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        # 检测是否为推理模型(o1/o3系列)
        model_name = self.config.model.lower()
        is_reasoning_model = "o1" in model_name or "o3" in model_name

        # 构建请求体
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
            "frequency_penalty": self.config.frequency_penalty,
            "presence_penalty": self.config.presence_penalty,
        }

        # 推理模型使用max_completion_tokens,其他模型使用max_tokens
        max_tokens = request.max_tokens if request.max_tokens is not None else self.config.max_tokens
        if is_reasoning_model:
            request_body["max_completion_tokens"] = max_tokens
        else:
            request_body["max_tokens"] = max_tokens

        url = f"{self.base_url.rstrip('/')}/chat/completions"

        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()
        choice = data.get("choices", [{}])[0]

        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定OpenAI模型")
        return True
@@ -1,80 +0,0 @@
"""
阿里云通义千问适配器 - 兼容OpenAI格式
"""

from typing import Dict, Any
from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class QwenAdapter(BaseLLMAdapter):
    """通义千问适配器"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.QWEN, "https://dashscope.aliyuncs.com/compatible-mode/v1")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "通义千问 API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # 通义千问兼容OpenAI格式
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        url = f"{self.base_url.rstrip('/')}/chat/completions"

        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()
        choice = data.get("choices", [{}])[0]

        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定通义千问模型")
        return True
@@ -1,80 +0,0 @@
"""
智谱AI适配器 (GLM系列) - 兼容OpenAI格式
"""

from typing import Dict, Any
from ..base_adapter import BaseLLMAdapter
from ..types import LLMRequest, LLMResponse, LLMUsage, DEFAULT_BASE_URLS, LLMProvider


class ZhipuAdapter(BaseLLMAdapter):
    """智谱AI适配器"""

    @property
    def base_url(self) -> str:
        return self.config.base_url or DEFAULT_BASE_URLS.get(LLMProvider.ZHIPU, "https://open.bigmodel.cn/api/paas/v4")

    async def complete(self, request: LLMRequest) -> LLMResponse:
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "智谱AI API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        # 智谱AI兼容OpenAI格式
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
        }

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        request_body: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        url = f"{self.base_url.rstrip('/')}/chat/completions"

        response = await self.client.post(
            url,
            headers=self.build_headers(headers),
            json=request_body
        )

        if response.status_code != 200:
            error_data = response.json() if response.text else {}
            error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status_code}")
            raise Exception(f"{error_msg}")

        data = response.json()
        choice = data.get("choices", [{}])[0]

        if not choice:
            raise Exception("API响应格式异常: 缺少choices字段")

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=choice.get("message", {}).get("content", ""),
            model=data.get("model"),
            usage=usage,
            finish_reason=choice.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        await super().validate_config()
        if not self.config.model:
            raise Exception("未指定智谱AI模型")
        return True
@@ -1,25 +1,29 @@
"""
LLM工厂类 - 统一创建和管理LLM适配器

使用 LiteLLM 作为主要适配器,支持大多数 LLM 提供商。
对于 API 格式特殊的提供商(百度、MiniMax、豆包),使用原生适配器。
"""

from typing import Dict, List, Optional
from typing import Dict, List
from .types import LLMConfig, LLMProvider, DEFAULT_MODELS
from .base_adapter import BaseLLMAdapter
from .adapters import (
    OpenAIAdapter,
    GeminiAdapter,
    ClaudeAdapter,
    DeepSeekAdapter,
    QwenAdapter,
    ZhipuAdapter,
    MoonshotAdapter,
    LiteLLMAdapter,
    BaiduAdapter,
    MinimaxAdapter,
    DoubaoAdapter,
    OllamaAdapter,
)


# 必须使用原生适配器的提供商(API 格式特殊)
NATIVE_ONLY_PROVIDERS = {
    LLMProvider.BAIDU,
    LLMProvider.MINIMAX,
    LLMProvider.DOUBAO,
}


class LLMFactory:
    """LLM工厂类"""

@@ -49,23 +53,29 @@ class LLMFactory:
        if not config.model:
            config.model = DEFAULT_MODELS.get(config.provider, "gpt-4o-mini")

        adapter_map = {
            LLMProvider.OPENAI: OpenAIAdapter,
            LLMProvider.GEMINI: GeminiAdapter,
            LLMProvider.CLAUDE: ClaudeAdapter,
            LLMProvider.DEEPSEEK: DeepSeekAdapter,
            LLMProvider.QWEN: QwenAdapter,
            LLMProvider.ZHIPU: ZhipuAdapter,
            LLMProvider.MOONSHOT: MoonshotAdapter,
        # 对于必须使用原生适配器的提供商
        if config.provider in NATIVE_ONLY_PROVIDERS:
            return cls._create_native_adapter(config)

        # 其他提供商使用 LiteLLM
        if LiteLLMAdapter.supports_provider(config.provider):
            return LiteLLMAdapter(config)

        # 不支持的提供商
        raise ValueError(f"不支持的LLM提供商: {config.provider}")

    @classmethod
    def _create_native_adapter(cls, config: LLMConfig) -> BaseLLMAdapter:
        """创建原生适配器(仅用于 API 格式特殊的提供商)"""
        native_adapter_map = {
            LLMProvider.BAIDU: BaiduAdapter,
            LLMProvider.MINIMAX: MinimaxAdapter,
            LLMProvider.DOUBAO: DoubaoAdapter,
            LLMProvider.OLLAMA: OllamaAdapter,
        }

        adapter_class = adapter_map.get(config.provider)
        adapter_class = native_adapter_map.get(config.provider)
        if not adapter_class:
            raise ValueError(f"不支持的LLM提供商: {config.provider}")
            raise ValueError(f"不支持的原生适配器提供商: {config.provider}")

        return adapter_class(config)

@@ -162,4 +172,3 @@ class LLMFactory:
            ],
        }
        return models.get(provider, [])
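To illustrate the routing above, a short sketch of creating an adapter through the factory; whether a LiteLLMAdapter or a native adapter comes back depends only on the provider. The key is a placeholder and the snippet is not part of the commit.

# Hypothetical end-to-end use of the factory; the provider determines the adapter class.
import asyncio
from app.services.llm.factory import LLMFactory
from app.services.llm.types import LLMConfig, LLMMessage, LLMProvider, LLMRequest


async def demo() -> None:
    config = LLMConfig(provider=LLMProvider.DEEPSEEK, api_key="sk-...", model="deepseek-chat")
    adapter = LLMFactory.create_adapter(config)  # -> LiteLLMAdapter for this provider
    response = await adapter.complete(
        LLMRequest(messages=[LLMMessage(role="user", content="Hi")], max_tokens=50)
    )
    print(response.content)


asyncio.run(demo())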
@@ -16,6 +16,9 @@ ACCESS_TOKEN_EXPIRE_MINUTES=11520

# ------------ LLM配置 ------------
# 支持的provider: openai, gemini, claude, qwen, deepseek, zhipu, moonshot, baidu, minimax, doubao, ollama
#
# 使用 LiteLLM 统一适配器: openai, gemini, claude, qwen, deepseek, zhipu, moonshot, ollama
# 使用原生适配器 (API格式特殊): baidu, minimax, doubao
LLM_PROVIDER=openai
LLM_API_KEY=sk-your-api-key
LLM_MODEL=
@@ -18,4 +18,5 @@ dependencies = [
    "email-validator",
    "greenlet",
    "bcrypt<5.0.0",
    "litellm>=1.0.0",
]

(File diff suppressed because it is too large)
@@ -273,6 +273,34 @@ export const api = {
    await apiClient.delete('/config/me');
  },

  async testLLMConnection(params: {
    provider: string;
    apiKey: string;
    model?: string;
    baseUrl?: string;
  }): Promise<{
    success: boolean;
    message: string;
    model?: string;
    response?: string;
  }> {
    const res = await apiClient.post('/config/test-llm', params);
    return res.data;
  },

  async getLLMProviders(): Promise<{
    providers: Array<{
      id: string;
      name: string;
      defaultModel: string;
      models: string[];
      defaultBaseUrl: string;
    }>;
  }> {
    const res = await apiClient.get('/config/llm-providers');
    return res.data;
  },

  // ==================== 数据库管理相关方法 ====================

  async exportDatabase(): Promise<{