"""
|
|
LiteLLM 统一适配器
|
|
支持通过 LiteLLM 调用多个 LLM 提供商,使用统一的 OpenAI 兼容格式
|
|
"""

from typing import Dict, Any, Optional

from ..base_adapter import BaseLLMAdapter
from ..types import (
    LLMConfig,
    LLMRequest,
    LLMResponse,
    LLMUsage,
    LLMProvider,
    LLMError,
    DEFAULT_BASE_URLS,
)


class LiteLLMAdapter(BaseLLMAdapter):
    """
    LiteLLM unified adapter.

    Supported providers:
    - OpenAI (openai/gpt-4o-mini)
    - Claude (anthropic/claude-3-5-sonnet-20241022)
    - Gemini (gemini/gemini-1.5-flash)
    - DeepSeek (deepseek/deepseek-chat)
    - Qwen (qwen/qwen-turbo) - via OpenAI-compatible mode
    - Zhipu (zhipu/glm-4-flash) - via OpenAI-compatible mode
    - Moonshot (moonshot/moonshot-v1-8k) - via OpenAI-compatible mode
    - Ollama (ollama/llama3)
    """

    # Mapping from provider to LiteLLM model-name prefix;
    # also serves as the registry checked by supports_provider().
    PROVIDER_PREFIX_MAP = {
        LLMProvider.OPENAI: "openai",
        LLMProvider.CLAUDE: "anthropic",
        LLMProvider.GEMINI: "gemini",
        LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "openai",  # uses OpenAI-compatible mode
        LLMProvider.ZHIPU: "openai",  # uses OpenAI-compatible mode
        LLMProvider.MOONSHOT: "openai",  # uses OpenAI-compatible mode
        LLMProvider.OLLAMA: "ollama",
    }

    # Providers that require a custom base_url
    CUSTOM_BASE_URL_PROVIDERS = {
        LLMProvider.QWEN,
        LLMProvider.ZHIPU,
        LLMProvider.MOONSHOT,
        LLMProvider.DEEPSEEK,
    }

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._litellm_model = self._get_litellm_model()
        self._api_base = self._get_api_base()

    def _get_litellm_model(self) -> str:
        """Return the model name in LiteLLM format."""
        provider = self.config.provider
        model = self.config.model

        # Providers routed through an OpenAI-compatible endpoint keep the
        # bare model name; routing is handled via the custom base_url instead.
        if provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return model

        # Natively supported providers get a LiteLLM prefix
        prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")

        # Skip prefixing if the model name already contains one
        if "/" in model:
            return model

        return f"{prefix}/{model}"
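
    # Illustrative results of the mapping above (model names are examples,
    # not verified identifiers):
    #   (OPENAI, "gpt-4o-mini")       -> "openai/gpt-4o-mini"
    #   (CLAUDE, "claude-3-5-sonnet") -> "anthropic/claude-3-5-sonnet"
    #   (OPENAI, "openai/gpt-4o")     -> "openai/gpt-4o"  (already prefixed)
    #   (QWEN,   "qwen-turbo")        -> "qwen-turbo"     (bare; routed via base_url)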

    def _get_api_base(self) -> Optional[str]:
        """Return the API base URL, or None to use LiteLLM's default."""
        # A user-configured base_url always takes precedence
        if self.config.base_url:
            return self.config.base_url

        # Providers that require a custom base_url fall back to the defaults
        if self.config.provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return DEFAULT_BASE_URLS.get(self.config.provider)

        # Ollama uses a local address
        if self.config.provider == LLMProvider.OLLAMA:
            return DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434")

        return None
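
    # Resolution order, summarized:
    #   1. config.base_url set                 -> used as-is
    #   2. provider in CUSTOM_BASE_URL_PROVIDERS -> DEFAULT_BASE_URLS[provider]
    #   3. provider is OLLAMA                  -> local default (http://localhost:11434)
    #   4. otherwise                           -> None (LiteLLM picks the provider default)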

    async def complete(self, request: LLMRequest) -> LLMResponse:
        """Send a request via LiteLLM."""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, f"LiteLLM ({self.config.provider.value}) API call failed")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """Send the request to LiteLLM."""
        import litellm

        # Disable LiteLLM's cache so every call actually hits the API
        litellm.cache = None

        # Build the message list
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # Build the request parameters; per-request values override config defaults
        kwargs: Dict[str, Any] = {
            "model": self._litellm_model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }
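
        # For example, an OPENAI config with model "gpt-4o-mini" might resolve to
        # (values illustrative; actual defaults come from LLMConfig):
        #   {"model": "openai/gpt-4o-mini", "messages": [...],
        #    "temperature": 0.7, "max_tokens": 4096, "top_p": 1.0}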

        # Set the API key (the literal "ollama" is treated as a placeholder and skipped)
        if self.config.api_key and self.config.api_key != "ollama":
            kwargs["api_key"] = self.config.api_key

        # Set the API base URL
        if self._api_base:
            kwargs["api_base"] = self._api_base

        # Set the timeout
        kwargs["timeout"] = self.config.timeout

        # The OpenAI provider accepts extra sampling parameters
        if self.config.provider == LLMProvider.OPENAI:
            kwargs["frequency_penalty"] = self.config.frequency_penalty
            kwargs["presence_penalty"] = self.config.presence_penalty

        try:
            # Call LiteLLM
            response = await litellm.acompletion(**kwargs)
        except litellm.exceptions.AuthenticationError as e:
            raise LLMError(f"API key is invalid or expired: {str(e)}", self.config.provider, 401)
        except litellm.exceptions.RateLimitError as e:
            raise LLMError(f"API rate limit exceeded: {str(e)}", self.config.provider, 429)
        except litellm.exceptions.APIConnectionError as e:
            raise LLMError(f"Unable to connect to the API service: {str(e)}", self.config.provider)
        except litellm.exceptions.APIError as e:
            raise LLMError(f"API error: {str(e)}", self.config.provider, getattr(e, 'status_code', None))
        except Exception as e:
            # Inspect other exceptions and re-raise
            error_msg = str(e)
            if "invalid_api_key" in error_msg.lower() or "incorrect api key" in error_msg.lower():
                raise LLMError(f"Invalid API key: {error_msg}", self.config.provider, 401)
            elif "authentication" in error_msg.lower():
                raise LLMError(f"Authentication failed: {error_msg}", self.config.provider, 401)
            raise

        # Parse the response
        if not response:
            raise LLMError("API returned an empty response", self.config.provider)

        choice = response.choices[0] if response.choices else None
        if not choice:
            raise LLMError("Malformed API response: missing 'choices' field", self.config.provider)

        usage = None
        if hasattr(response, "usage") and response.usage:
            usage = LLMUsage(
                prompt_tokens=response.usage.prompt_tokens or 0,
                completion_tokens=response.usage.completion_tokens or 0,
                total_tokens=response.usage.total_tokens or 0,
            )

        return LLMResponse(
            content=choice.message.content or "",
            model=response.model,
            usage=usage,
            finish_reason=choice.finish_reason,
        )

    async def validate_config(self) -> bool:
        """Validate the configuration."""
        # Ollama does not require an API key
        if self.config.provider == LLMProvider.OLLAMA:
            if not self.config.model:
                raise LLMError("No Ollama model specified", LLMProvider.OLLAMA)
            return True

        # All other providers require an API key
        if not self.config.api_key:
            raise LLMError(
                f"API key not configured ({self.config.provider.value})",
                self.config.provider,
            )

        if not self.config.model:
            raise LLMError(
                f"No model specified ({self.config.provider.value})",
                self.config.provider,
            )

        return True

    @classmethod
    def supports_provider(cls, provider: LLMProvider) -> bool:
        """Check whether the given provider is supported."""
        return provider in cls.PROVIDER_PREFIX_MAP
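

# Example usage (illustrative sketch only; the import paths, the LLMMessage
# type, and the exact LLMConfig/LLMRequest fields are assumptions -- adjust
# to the actual definitions in ..types):
#
#     import asyncio
#     from mypackage.types import LLMConfig, LLMRequest, LLMMessage, LLMProvider
#     from mypackage.adapters.litellm_adapter import LiteLLMAdapter
#
#     async def main() -> None:
#         config = LLMConfig(
#             provider=LLMProvider.OPENAI,
#             model="gpt-4o-mini",
#             api_key="sk-...",
#         )
#         adapter = LiteLLMAdapter(config)
#         response = await adapter.complete(
#             LLMRequest(messages=[LLMMessage(role="user", content="Hello")])
#         )
#         print(response.content, response.usage)
#
#     asyncio.run(main())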