CodeReview/backend/app/services/llm/adapters/litellm_adapter.py

"""
LiteLLM 统一适配器
支持通过 LiteLLM 调用多个 LLM 提供商,使用统一的 OpenAI 兼容格式
"""
from typing import Dict, Any, Optional

from ..base_adapter import BaseLLMAdapter
from ..types import (
    LLMConfig,
    LLMRequest,
    LLMResponse,
    LLMUsage,
    LLMProvider,
    LLMError,
    DEFAULT_BASE_URLS,
)


class LiteLLMAdapter(BaseLLMAdapter):
    """
    LiteLLM unified adapter.

    Supported providers:
    - OpenAI (openai/gpt-4o-mini)
    - Claude (anthropic/claude-3-5-sonnet-20241022)
    - Gemini (gemini/gemini-1.5-flash)
    - DeepSeek (deepseek/deepseek-chat)
    - Qwen (qwen/qwen-turbo) - via OpenAI-compatible mode
    - Zhipu (zhipu/glm-4-flash) - via OpenAI-compatible mode
    - Moonshot (moonshot/moonshot-v1-8k) - via OpenAI-compatible mode
    - Ollama (ollama/llama3)
    """

    # LiteLLM model prefix mapping
    PROVIDER_PREFIX_MAP = {
        LLMProvider.OPENAI: "openai",
        LLMProvider.CLAUDE: "anthropic",
        LLMProvider.GEMINI: "gemini",
        LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "openai",  # OpenAI-compatible mode
        LLMProvider.ZHIPU: "openai",  # OpenAI-compatible mode
        LLMProvider.MOONSHOT: "openai",  # OpenAI-compatible mode
        LLMProvider.OLLAMA: "ollama",
    }

    # Providers that require a custom base_url
    CUSTOM_BASE_URL_PROVIDERS = {
        LLMProvider.QWEN,
        LLMProvider.ZHIPU,
        LLMProvider.MOONSHOT,
        LLMProvider.DEEPSEEK,
    }
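    # Note: DeepSeek appears both in PROVIDER_PREFIX_MAP and here; membership in
    # this set wins in _get_litellm_model, so DeepSeek is effectively called
    # through its OpenAI-compatible endpoint rather than the native "deepseek/"
    # prefix listed above.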

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._litellm_model = self._get_litellm_model()
        self._api_base = self._get_api_base()

    def _get_litellm_model(self) -> str:
        """Build the model name in LiteLLM format."""
        provider = self.config.provider
        model = self.config.model
        # Providers called in OpenAI-compatible mode use the bare model name
        if provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return model
        # Natively supported providers get a LiteLLM prefix
        prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")
        # Leave the model name alone if it already carries a prefix
        if "/" in model:
            return model
        return f"{prefix}/{model}"
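    # Illustrative results, traced from the logic above (model names are examples):
    #   CLAUDE + "claude-3-5-sonnet-20241022" -> "anthropic/claude-3-5-sonnet-20241022"
    #   GEMINI + "gemini/gemini-1.5-flash"    -> "gemini/gemini-1.5-flash" (prefix kept)
    #   QWEN   + "qwen-turbo"                 -> "qwen-turbo" (bare name; routed via base_url)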

    def _get_api_base(self) -> Optional[str]:
        """Resolve the API base URL."""
        # A user-configured base_url always wins
        if self.config.base_url:
            return self.config.base_url
        # Providers that need a custom base_url fall back to the defaults table
        if self.config.provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return DEFAULT_BASE_URLS.get(self.config.provider)
        # Ollama runs against a local address
        if self.config.provider == LLMProvider.OLLAMA:
            return DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434")
        return None
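    # Resolution order: config.base_url -> DEFAULT_BASE_URLS entry -> None
    # (returning None lets LiteLLM fall back to the provider's own default endpoint).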

    async def complete(self, request: LLMRequest) -> LLMResponse:
        """Send a request through LiteLLM."""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, f"LiteLLM ({self.config.provider.value}) API call failed")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """Send the request to LiteLLM."""
        import litellm

        # Disable LiteLLM's cache so every call actually hits the API
        litellm.cache = None

        # Build the message list
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # Build request parameters; per-request values override config defaults
        kwargs: Dict[str, Any] = {
            "model": self._litellm_model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }
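        # For example, a DeepSeek config with no per-request overrides might
        # resolve to (illustrative values only; actual defaults live in LLMConfig):
        #   {"model": "deepseek-chat", "messages": [...],
        #    "temperature": 0.7, "max_tokens": 4096, "top_p": 1.0}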
        # Set the API key ("ollama" is treated as a placeholder, not a real credential)
        if self.config.api_key and self.config.api_key != "ollama":
            kwargs["api_key"] = self.config.api_key

        # Set the API base URL
        if self._api_base:
            kwargs["api_base"] = self._api_base

        # Set the timeout
        kwargs["timeout"] = self.config.timeout

        # OpenAI supports extra sampling-penalty parameters
        if self.config.provider == LLMProvider.OPENAI:
            kwargs["frequency_penalty"] = self.config.frequency_penalty
            kwargs["presence_penalty"] = self.config.presence_penalty

        try:
            # Call LiteLLM
            response = await litellm.acompletion(**kwargs)
        except litellm.exceptions.AuthenticationError as e:
            raise LLMError(f"API key invalid or expired: {str(e)}", self.config.provider, 401)
        except litellm.exceptions.RateLimitError as e:
            raise LLMError(f"API rate limit exceeded: {str(e)}", self.config.provider, 429)
        except litellm.exceptions.APIConnectionError as e:
            raise LLMError(f"Could not connect to the API service: {str(e)}", self.config.provider)
        except litellm.exceptions.APIError as e:
            raise LLMError(f"API error: {str(e)}", self.config.provider, getattr(e, 'status_code', None))
        except Exception as e:
            # Map other authentication-looking failures, then re-raise anything else
            error_msg = str(e)
            if "invalid_api_key" in error_msg.lower() or "incorrect api key" in error_msg.lower():
                raise LLMError(f"API key invalid: {error_msg}", self.config.provider, 401)
            elif "authentication" in error_msg.lower():
                raise LLMError(f"Authentication failed: {error_msg}", self.config.provider, 401)
            raise

        # Parse the response
        if not response:
            raise LLMError("API returned an empty response", self.config.provider)

        choice = response.choices[0] if response.choices else None
        if not choice:
            raise LLMError("Malformed API response: missing choices field", self.config.provider)

        usage = None
        if hasattr(response, "usage") and response.usage:
            usage = LLMUsage(
                prompt_tokens=response.usage.prompt_tokens or 0,
                completion_tokens=response.usage.completion_tokens or 0,
                total_tokens=response.usage.total_tokens or 0,
            )

        return LLMResponse(
            content=choice.message.content or "",
            model=response.model,
            usage=usage,
            finish_reason=choice.finish_reason,
        )
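    # A normalized result looks like (illustrative values):
    #   LLMResponse(content="...", model="deepseek-chat",
    #               usage=LLMUsage(prompt_tokens=12, completion_tokens=34, total_tokens=46),
    #               finish_reason="stop")
    # usage stays None for providers that do not report token counts.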

    async def validate_config(self) -> bool:
        """Validate the adapter configuration."""
        # Ollama does not need an API key
        if self.config.provider == LLMProvider.OLLAMA:
            if not self.config.model:
                raise LLMError("No Ollama model specified", LLMProvider.OLLAMA)
            return True

        # All other providers need an API key
        if not self.config.api_key:
            raise LLMError(
                f"API key not configured ({self.config.provider.value})",
                self.config.provider,
            )
        if not self.config.model:
            raise LLMError(
                f"No model specified ({self.config.provider.value})",
                self.config.provider,
            )
        return True

    @classmethod
    def supports_provider(cls, provider: LLMProvider) -> bool:
        """Check whether the given provider is supported."""
        return provider in cls.PROVIDER_PREFIX_MAP
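

# Usage sketch. Hedged: the field names on LLMConfig/LLMRequest and the message
# type are assumptions inferred from how they are read above (e.g. a hypothetical
# LLMMessage with .role/.content); adjust to the actual definitions in ..types.
#
#     config = LLMConfig(provider=LLMProvider.DEEPSEEK, model="deepseek-chat",
#                        api_key="sk-...")
#     adapter = LiteLLMAdapter(config)
#     response = await adapter.complete(
#         LLMRequest(messages=[LLMMessage(role="user", content="Hello")])
#     )
#     print(response.content)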