CodeReview/backend/app/services/llm/adapters/litellm_adapter.py


"""
LiteLLM unified adapter.
Calls multiple LLM providers through LiteLLM using a single OpenAI-compatible interface.
Enhancements:
- Prompt Caching: adds cache markers for LLMs that support it (e.g. Claude)
- Smart retries: exponential-backoff retry strategy
- Streaming: supports token-by-token output
"""
import logging
from typing import Dict, Any, Optional, List
from ..base_adapter import BaseLLMAdapter
from ..types import (
    LLMConfig,
    LLMRequest,
    LLMResponse,
    LLMUsage,
    LLMProvider,
    LLMError,
    DEFAULT_BASE_URLS,
)
from ..prompt_cache import prompt_cache_manager, estimate_tokens

logger = logging.getLogger(__name__)


class LiteLLMAdapter(BaseLLMAdapter):
    """
    LiteLLM unified adapter.

    Supported providers:
    - OpenAI (openai/gpt-4o-mini)
    - Claude (anthropic/claude-3-5-sonnet-20241022)
    - Gemini (gemini/gemini-1.5-flash)
    - DeepSeek (deepseek/deepseek-chat)
    - Qwen (qwen/qwen-turbo) - via OpenAI-compatible mode
    - Zhipu (zhipu/glm-4-flash) - via OpenAI-compatible mode
    - Moonshot (moonshot/moonshot-v1-8k) - via OpenAI-compatible mode
    - Ollama (ollama/llama3)
    """

    # LiteLLM model-prefix mapping
    PROVIDER_PREFIX_MAP = {
        LLMProvider.OPENAI: "openai",
        LLMProvider.CLAUDE: "anthropic",
        LLMProvider.GEMINI: "gemini",
        LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "openai",      # via OpenAI-compatible mode
        LLMProvider.ZHIPU: "openai",     # via OpenAI-compatible mode
        LLMProvider.MOONSHOT: "openai",  # via OpenAI-compatible mode
        LLMProvider.OLLAMA: "ollama",
    }

    # Providers that require a custom base_url
    CUSTOM_BASE_URL_PROVIDERS = {
        LLMProvider.QWEN,
        LLMProvider.ZHIPU,
        LLMProvider.MOONSHOT,
        LLMProvider.DEEPSEEK,
    }

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._litellm_model = self._get_litellm_model()
        self._api_base = self._get_api_base()
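
    # Usage sketch (illustrative only; "LLMMessage" below stands in for whatever
    # message type populates LLMRequest.messages -- anything exposing .role and
    # .content works; the other field names follow how this adapter reads
    # LLMConfig and LLMRequest):
    #
    #     config = LLMConfig(
    #         provider=LLMProvider.DEEPSEEK,
    #         model="deepseek-chat",
    #         api_key="sk-...",
    #     )
    #     adapter = LiteLLMAdapter(config)
    #     response = await adapter.complete(
    #         LLMRequest(messages=[LLMMessage(role="user", content="Review this diff")])
    #     )
    #     print(response.content, response.usage)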

    def _get_litellm_model(self) -> str:
        """Build the model name in LiteLLM format.

        For third-party OpenAI-compatible APIs (e.g. SiliconFlow):
        - if the user configured a custom base_url and the model name contains "/"
          (e.g. Qwen/Qwen3-8B),
        - it must be rewritten as openai/Qwen/Qwen3-8B,
        - because LiteLLM only accepts "openai" as a valid prefix in that case.
        """
        provider = self.config.provider
        model = self.config.model
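
        # Illustrative mappings implied by the logic below (not exhaustive):
        #   provider=CLAUDE, model="claude-3-5-sonnet-20241022" -> "anthropic/claude-3-5-sonnet-20241022"
        #   provider=QWEN,   model="qwen-turbo"                 -> "openai/qwen-turbo"  (OpenAI-compatible mode)
        #   custom base_url, model="Qwen/Qwen3-8B"              -> "openai/Qwen/Qwen3-8B"
        #   already prefixed, e.g. "gemini/gemini-1.5-flash"    -> returned unchanged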

        # Check whether the model name already carries a prefix
        if "/" in model:
            # Treat the first segment as a potential provider prefix
            prefix_part = model.split("/")[0].lower()

            # Provider prefixes that LiteLLM recognizes
            valid_litellm_prefixes = [
                "openai", "anthropic", "gemini", "deepseek", "ollama",
                "azure", "huggingface", "together", "groq", "mistral",
                "anyscale", "replicate", "bedrock", "vertex_ai", "cohere",
                "sagemaker", "palm", "ai21", "nlp_cloud", "aleph_alpha",
                "petals", "baseten", "vllm", "cloudflare", "xinference"
            ]

            # If the prefix is one LiteLLM knows, return the model name as-is
            if prefix_part in valid_litellm_prefixes:
                return model

            # If the user set a custom base_url, treat it as an OpenAI-compatible API
            # (e.g. SiliconFlow uses model names like "Qwen/Qwen3-8B")
            if self.config.base_url:
                logger.debug(f"Custom base_url in use; treating model {model} as OpenAI-compatible")
                return f"openai/{model}"

            # Without a custom base_url, fall back to the provider's prefix
            prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")
            return f"{prefix}/{model}"

        # Model name has no prefix: prepend the provider's prefix
        prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")
        return f"{prefix}/{model}"

    def _extract_api_response(self, error: Exception) -> Optional[str]:
        """Extract the raw API server response message from an exception."""
        error_str = str(error)
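
        # Illustrative example of what this parser handles (hypothetical error text,
        # not captured from a real provider):
        #   "litellm.AuthenticationError: ... {'error': {'code': 'invalid_api_key',
        #    'message': 'Incorrect API key provided'}}"
        # would be reduced to "[invalid_api_key] Incorrect API key provided".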

        # Try to pull a JSON-formatted error payload out of the message
        import re
        import json

        # Match {'error': {...}} or {"error": {...}}
        json_pattern = r"\{['\"]error['\"]:\s*\{[^}]+\}\}"
        match = re.search(json_pattern, error_str)
        if match:
            try:
                # Swap single quotes for double quotes so the snippet parses as JSON
                json_str = match.group().replace("'", '"')
                error_obj = json.loads(json_str)
                if 'error' in error_obj:
                    err = error_obj['error']
                    code = err.get('code', '')
                    message = err.get('message', '')
                    return f"[{code}] {message}" if code else message
            except Exception:
                # Fall through to the looser extraction strategies below
                pass

        # Try to extract a bare "message" field
        message_pattern = r"['\"]message['\"]:\s*['\"]([^'\"]+)['\"]"
        match = re.search(message_pattern, error_str)
        if match:
            return match.group(1)

        # Fall back to whatever litellm exposes on the exception itself
        if hasattr(error, 'message'):
            return error.message
        if hasattr(error, 'llm_provider'):
            # litellm exceptions usually carry the original error text
            return error_str.split(' - ')[-1] if ' - ' in error_str else None
        return None

    def _get_api_base(self) -> Optional[str]:
        """Resolve the API base URL."""
        # A user-configured base_url always wins
        if self.config.base_url:
            return self.config.base_url

        # Providers that require a custom base_url fall back to their defaults
        if self.config.provider in self.CUSTOM_BASE_URL_PROVIDERS:
            return DEFAULT_BASE_URLS.get(self.config.provider)

        # Ollama uses the local address
        if self.config.provider == LLMProvider.OLLAMA:
            return DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434")

        return None

    async def complete(self, request: LLMRequest) -> LLMResponse:
        """Send a request through LiteLLM."""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, f"LiteLLM ({self.config.provider.value}) API调用失败")

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """Send the request to LiteLLM."""
        import litellm

        # Uncomment the next line to enable LiteLLM debug mode for more
        # detailed error information.
        # litellm._turn_on_debug()

        # Disable LiteLLM's cache so every call actually hits the API
        litellm.cache = None

        # Have LiteLLM drop params it would otherwise inject (e.g. reasoning_effort);
        # this prevents the model name from being misparsed into an effort parameter
        litellm.drop_params = True

        # Build the message list
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # 🔥 Prompt Caching: add cache markers for LLMs that support it
        cache_enabled = False
        if self.config.provider == LLMProvider.CLAUDE:
            # Estimate the system-prompt token count
            system_tokens = 0
            for msg in messages:
                if msg.get("role") == "system":
                    system_tokens += estimate_tokens(msg.get("content", ""))
            messages, cache_enabled = prompt_cache_manager.process_messages(
                messages=messages,
                model=self.config.model,
                provider=self.config.provider.value,
                system_prompt_tokens=system_tokens,
            )

        if cache_enabled:
            logger.debug(f"🔥 Prompt Caching enabled for {self.config.model}")
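
        # Sketch of what a cache-marked message typically looks like after processing
        # (based on Anthropic-style prompt caching as exposed through LiteLLM; the exact
        # shape produced by prompt_cache_manager is an assumption here):
        #
        #     {
        #         "role": "system",
        #         "content": [
        #             {"type": "text", "text": "<long system prompt>",
        #              "cache_control": {"type": "ephemeral"}},
        #         ],
        #     }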

        # Build the request parameters
        kwargs: Dict[str, Any] = {
            "model": self._litellm_model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }

        # API key
        if self.config.api_key and self.config.api_key != "ollama":
            kwargs["api_key"] = self.config.api_key

        # API base URL
        if self._api_base:
            kwargs["api_base"] = self._api_base
            logger.debug(f"🔗 Using custom API base: {self._api_base}")

        # Timeout
        kwargs["timeout"] = self.config.timeout

        # Extra parameters for the OpenAI provider
        if self.config.provider == LLMProvider.OPENAI:
            kwargs["frequency_penalty"] = self.config.frequency_penalty
            kwargs["presence_penalty"] = self.config.presence_penalty

        try:
            # Call LiteLLM
            response = await litellm.acompletion(**kwargs)
        except litellm.exceptions.AuthenticationError as e:
            api_response = self._extract_api_response(e)
            raise LLMError("API Key 无效或已过期", self.config.provider, 401, api_response=api_response)
        except litellm.exceptions.RateLimitError as e:
            error_msg = str(e)
            api_response = self._extract_api_response(e)
            # Distinguish "insufficient balance" from "rate limited"
            if any(keyword in error_msg for keyword in ["余额不足", "资源包", "充值", "quota", "insufficient", "balance"]):
                raise LLMError("账户余额不足或配额已用尽,请充值后重试", self.config.provider, 402, api_response=api_response)
            raise LLMError("API 调用频率超限,请稍后重试", self.config.provider, 429, api_response=api_response)
        except litellm.exceptions.APIConnectionError as e:
            api_response = self._extract_api_response(e)
            raise LLMError("无法连接到 API 服务", self.config.provider, api_response=api_response)
        except (litellm.exceptions.ServiceUnavailableError, litellm.exceptions.InternalServerError) as e:
            api_response = self._extract_api_response(e)
            raise LLMError(f"API 服务暂时不可用 ({type(e).__name__})", self.config.provider, 503, api_response=api_response)
        except litellm.exceptions.APIError as e:
            api_response = self._extract_api_response(e)
            raise LLMError("API 错误", self.config.provider, getattr(e, 'status_code', None), api_response=api_response)
        except Exception as e:
            # Catch anything else and re-raise with better context where possible
            error_msg = str(e)
            api_response = self._extract_api_response(e)
            if "invalid_api_key" in error_msg.lower() or "incorrect api key" in error_msg.lower():
                raise LLMError("API Key 无效", self.config.provider, 401, api_response=api_response)
            elif "authentication" in error_msg.lower():
                raise LLMError("认证失败", self.config.provider, 401, api_response=api_response)
            elif any(keyword in error_msg for keyword in ["余额不足", "资源包", "充值", "quota", "insufficient", "balance"]):
                raise LLMError("账户余额不足或配额已用尽", self.config.provider, 402, api_response=api_response)
            raise

        # Parse the response
        if not response:
            raise LLMError("API 返回空响应", self.config.provider)

        choice = response.choices[0] if response.choices else None
        if not choice:
            raise LLMError("API响应格式异常: 缺少choices字段", self.config.provider)

        usage = None
        if hasattr(response, "usage") and response.usage:
            usage = LLMUsage(
                prompt_tokens=response.usage.prompt_tokens or 0,
                completion_tokens=response.usage.completion_tokens or 0,
                total_tokens=response.usage.total_tokens or 0,
            )

            # 🔥 Update prompt-cache statistics
            if cache_enabled and hasattr(response.usage, "cache_creation_input_tokens"):
                prompt_cache_manager.update_stats(
                    cache_creation_input_tokens=getattr(response.usage, "cache_creation_input_tokens", 0),
                    cache_read_input_tokens=getattr(response.usage, "cache_read_input_tokens", 0),
                    total_input_tokens=response.usage.prompt_tokens or 0,
                )

        return LLMResponse(
            content=choice.message.content or "",
            model=response.model,
            usage=usage,
            finish_reason=choice.finish_reason,
        )

    async def stream_complete(self, request: LLMRequest):
        """
        Stream the LLM call, yielding output token by token.

        Yields:
            dict: {"type": "token", "content": str} or
                  {"type": "done", "content": str, "usage": dict}
        """
        import litellm
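
        # Consumption sketch (illustrative; event shapes follow the yields below,
        # and "error" events may also be emitted -- handle_event() is hypothetical):
        #
        #     async for event in adapter.stream_complete(request):
        #         if event["type"] == "token":
        #             print(event["content"], end="", flush=True)
        #         elif event["type"] == "done":
        #             usage = event["usage"]
        #         elif event["type"] == "error":
        #             handle_event(event["user_message"])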

        await self.validate_config()

        litellm.cache = None
        litellm.drop_params = True

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # 🔥 Estimate input tokens (used as a fallback when real usage is unavailable)
        input_tokens_estimate = sum(estimate_tokens(msg["content"]) for msg in messages)

        kwargs = {
            "model": self._litellm_model,
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
            "stream": True,  # enable streaming
        }

        # 🔥 For models that support it, ask for usage info in the stream
        # (the OpenAI API supports stream_options)
        if self.config.provider in [LLMProvider.OPENAI, LLMProvider.DEEPSEEK]:
            kwargs["stream_options"] = {"include_usage": True}
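        # Note: with include_usage, OpenAI-style streams typically append one final
        # chunk with an empty "choices" list and a populated "usage" object; the loop
        # below picks that up via the hasattr(chunk, "usage") check. (Stated as a
        # general expectation, not verified for every provider.)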

        if self.config.api_key and self.config.api_key != "ollama":
            kwargs["api_key"] = self.config.api_key
        if self._api_base:
            kwargs["api_base"] = self._api_base
        kwargs["timeout"] = self.config.timeout

        accumulated_content = ""
        final_usage = None  # 🔥 holds the final usage info, if any
        chunk_count = 0     # 🔥 number of chunks seen

        try:
            response = await litellm.acompletion(**kwargs)
            async for chunk in response:
                chunk_count += 1

                # 🔥 Check for usage info (some APIs include it in the final chunk)
                if hasattr(chunk, "usage") and chunk.usage:
                    final_usage = {
                        "prompt_tokens": chunk.usage.prompt_tokens or 0,
                        "completion_tokens": chunk.usage.completion_tokens or 0,
                        "total_tokens": chunk.usage.total_tokens or 0,
                    }
                    logger.debug(f"Got usage from chunk: {final_usage}")

                if not chunk.choices:
                    # 🔥 Some models send chunks without choices (e.g. heartbeats)
                    continue

                delta = chunk.choices[0].delta
                content = getattr(delta, "content", "") or ""
                finish_reason = chunk.choices[0].finish_reason

                if content:
                    accumulated_content += content
                    yield {
                        "type": "token",
                        "content": content,
                        "accumulated": accumulated_content,
                    }

                # 🔥 ENHANCED: handle chunks that carry a finish_reason but no content;
                # some models (e.g. Zhipu GLM) may omit content in certain chunks
                if finish_reason:
                    # Stream finished
                    # 🔥 Estimate usage if no chunk supplied it
                    if not final_usage:
                        output_tokens_estimate = estimate_tokens(accumulated_content)
                        final_usage = {
                            "prompt_tokens": input_tokens_estimate,
                            "completion_tokens": output_tokens_estimate,
                            "total_tokens": input_tokens_estimate + output_tokens_estimate,
                        }
                        logger.debug(f"Estimated usage: {final_usage}")

                    # 🔥 ENHANCED: warn if the stream finished without any content
                    if not accumulated_content:
                        logger.warning(f"Stream completed with no content after {chunk_count} chunks, finish_reason={finish_reason}")

                    yield {
                        "type": "done",
                        "content": accumulated_content,
                        "usage": final_usage,
                        "finish_reason": finish_reason,
                    }
                    break

            # 🔥 ENHANCED: if the loop ends without a finish_reason, still emit "done"
            if accumulated_content:
                logger.warning(f"Stream ended without finish_reason, returning accumulated content ({len(accumulated_content)} chars)")
                if not final_usage:
                    output_tokens_estimate = estimate_tokens(accumulated_content)
                    final_usage = {
                        "prompt_tokens": input_tokens_estimate,
                        "completion_tokens": output_tokens_estimate,
                        "total_tokens": input_tokens_estimate + output_tokens_estimate,
                    }
                yield {
                    "type": "done",
                    "content": accumulated_content,
                    "usage": final_usage,
                    "finish_reason": "complete",
                }
        except litellm.exceptions.RateLimitError as e:
            # Rate-limit errors need special handling
            logger.error(f"Stream rate limit error: {e}")
            error_msg = str(e)

            # Distinguish "quota exhausted" from "rate limited"
            if any(keyword in error_msg.lower() for keyword in ["余额不足", "资源包", "充值", "quota", "exceeded", "billing"]):
                error_type = "quota_exceeded"
                user_message = "API 配额已用尽,请检查账户余额或升级计划"
            else:
                error_type = "rate_limit"
                # Try to extract a retry delay from the error message
                import re
                retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
                retry_seconds = float(retry_match.group(1)) if retry_match else 60
                user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"

            output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
            yield {
                "type": "error",
                "error_type": error_type,
                "error": error_msg,
                "user_message": user_message,
                "accumulated": accumulated_content,
                "usage": {
                    "prompt_tokens": input_tokens_estimate,
                    "completion_tokens": output_tokens_estimate,
                    "total_tokens": input_tokens_estimate + output_tokens_estimate,
                } if accumulated_content else None,
            }
        except litellm.exceptions.AuthenticationError as e:
            # Authentication error: invalid API key
            logger.error(f"Stream authentication error: {e}")
            yield {
                "type": "error",
                "error_type": "authentication",
                "error": str(e),
                "user_message": "API Key 无效或已过期,请检查配置",
                "accumulated": accumulated_content,
                "usage": None,
            }
        except litellm.exceptions.APIConnectionError as e:
            # Connection error: network problem
            logger.error(f"Stream connection error: {e}")
            yield {
                "type": "error",
                "error_type": "connection",
                "error": str(e),
                "user_message": "无法连接到 API 服务,请检查网络连接",
                "accumulated": accumulated_content,
                "usage": None,
            }
        except (litellm.exceptions.ServiceUnavailableError, litellm.exceptions.InternalServerError) as e:
            # Service unavailable: server-side 5xx error
            logger.error(f"Stream server error ({type(e).__name__}): {e}")
            yield {
                "type": "error",
                "error_type": "server_error",
                "error": str(e),
                "user_message": f"API 服务暂时不可用 ({type(e).__name__})",
                "accumulated": accumulated_content,
                "usage": None,
            }
        except Exception as e:
            # Anything else: check whether it is a wrapped rate-limit error
            error_msg = str(e)
            logger.error(f"Stream error: {e}")

            # Detect wrapped rate-limit errors (e.g. a ServiceUnavailableError wrapping a RateLimitError)
            is_rate_limit = any(keyword in error_msg.lower() for keyword in [
                "ratelimiterror", "rate limit", "429", "resource_exhausted",
                "quota exceeded", "too many requests"
            ])

            if is_rate_limit:
                # Handle it as a rate-limit error
                import re
                # Is the quota exhausted?
                if any(keyword in error_msg.lower() for keyword in ["quota", "exceeded", "billing"]):
                    error_type = "quota_exceeded"
                    user_message = "API 配额已用尽,请检查账户余额或升级计划"
                else:
                    error_type = "rate_limit"
                    retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
                    retry_seconds = float(retry_match.group(1)) if retry_match else 60
                    user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
            else:
                error_type = "unknown"
                user_message = "LLM 调用发生错误,请重试"

            output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
            yield {
                "type": "error",
                "error_type": error_type,
                "error": error_msg,
                "user_message": user_message,
                "accumulated": accumulated_content,
                "usage": {
                    "prompt_tokens": input_tokens_estimate,
                    "completion_tokens": output_tokens_estimate,
                    "total_tokens": input_tokens_estimate + output_tokens_estimate,
                } if accumulated_content else None,
            }

    async def validate_config(self) -> bool:
        """Validate the adapter configuration."""
        # Ollama does not need an API key
        if self.config.provider == LLMProvider.OLLAMA:
            if not self.config.model:
                raise LLMError("未指定 Ollama 模型", LLMProvider.OLLAMA)
            return True

        # Every other provider needs an API key
        if not self.config.api_key:
            raise LLMError(
                f"API Key未配置 ({self.config.provider.value})",
                self.config.provider,
            )

        # Reject placeholder keys
        if "sk-your-" in self.config.api_key or "***" in self.config.api_key:
            raise LLMError(
                f"无效的 API Key (使用了占位符): {self.config.api_key[:10]}...",
                self.config.provider,
                401
            )

        if not self.config.model:
            raise LLMError(
                f"未指定模型 ({self.config.provider.value})",
                self.config.provider,
            )

        return True

    @classmethod
    def supports_provider(cls, provider: LLMProvider) -> bool:
        """Check whether the given provider is supported."""
        return provider in cls.PROVIDER_PREFIX_MAP