CodeReview/backend/app/services/llm/adapters/litellm_adapter.py

"""
LiteLLM 统一适配器
支持通过 LiteLLM 调用多个 LLM 提供商,使用统一的 OpenAI 兼容格式
增强功能:
- Prompt Caching: 为支持的 LLM如 Claude添加缓存标记
- 智能重试: 指数退避重试策略
- 流式输出: 支持逐 token 返回
"""
import logging
from typing import Dict, Any, Optional, List
from ..base_adapter import BaseLLMAdapter
from ..types import (
LLMConfig,
LLMRequest,
LLMResponse,
LLMUsage,
LLMProvider,
LLMError,
DEFAULT_BASE_URLS,
)
from ..prompt_cache import prompt_cache_manager, estimate_tokens
logger = logging.getLogger(__name__)
class LiteLLMAdapter(BaseLLMAdapter):
"""
LiteLLM 统一适配器
支持的提供商:
- OpenAI (openai/gpt-4o-mini)
- Claude (anthropic/claude-3-5-sonnet-20241022)
- Gemini (gemini/gemini-1.5-flash)
- DeepSeek (deepseek/deepseek-chat)
- Qwen (qwen/qwen-turbo) - 通过 OpenAI 兼容模式
- Zhipu (zhipu/glm-4-flash) - 通过 OpenAI 兼容模式
- Moonshot (moonshot/moonshot-v1-8k) - 通过 OpenAI 兼容模式
- Ollama (ollama/llama3)
"""
    # Mapping from provider to LiteLLM model-name prefix
PROVIDER_PREFIX_MAP = {
LLMProvider.OPENAI: "openai",
LLMProvider.CLAUDE: "anthropic",
LLMProvider.GEMINI: "gemini",
LLMProvider.DEEPSEEK: "deepseek",
        LLMProvider.QWEN: "openai",  # uses OpenAI-compatible mode
        LLMProvider.ZHIPU: "openai",  # uses OpenAI-compatible mode
        LLMProvider.MOONSHOT: "openai",  # uses OpenAI-compatible mode
LLMProvider.OLLAMA: "ollama",
}
    # Providers that require a custom base_url
CUSTOM_BASE_URL_PROVIDERS = {
LLMProvider.QWEN,
LLMProvider.ZHIPU,
LLMProvider.MOONSHOT,
LLMProvider.DEEPSEEK,
}
def __init__(self, config: LLMConfig):
super().__init__(config)
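        # Precompute the LiteLLM-format model name and the API base once, at construction time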
self._litellm_model = self._get_litellm_model()
self._api_base = self._get_api_base()
def _get_litellm_model(self) -> str:
"""获取 LiteLLM 格式的模型名称
对于使用第三方 OpenAI 兼容 API如 SiliconFlow的情况
- 如果用户设置了自定义 base_url且模型名包含 / (如 Qwen/Qwen3-8B)
- 需要将其转换为 openai/Qwen/Qwen3-8B 格式
- 因为 LiteLLM 只认识 openai 作为有效前缀
"""
provider = self.config.provider
model = self.config.model
        # Check whether the model name already carries a prefix
if "/" in model:
            # Treat the first path segment as a possible provider prefix
            prefix_part = model.split("/")[0].lower()
            # Provider prefixes that LiteLLM recognises
valid_litellm_prefixes = [
"openai", "anthropic", "gemini", "deepseek", "ollama",
"azure", "huggingface", "together", "groq", "mistral",
"anyscale", "replicate", "bedrock", "vertex_ai", "cohere",
"sagemaker", "palm", "ai21", "nlp_cloud", "aleph_alpha",
"petals", "baseten", "vllm", "cloudflare", "xinference"
]
            # If the prefix is one LiteLLM knows, return the name unchanged
if prefix_part in valid_litellm_prefixes:
return model
            # If the user configured a custom base_url, treat it as an OpenAI-compatible API;
            # SiliconFlow, for example, uses model names like "Qwen/Qwen3-8B"
            if self.config.base_url:
                logger.debug(f"Custom base_url set; treating model {model} as OpenAI-compatible")
                return f"openai/{model}"
            # Without a custom base_url, fall back to the provider's own prefix
prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")
return f"{prefix}/{model}"
        # Prepend the provider prefix
prefix = self.PROVIDER_PREFIX_MAP.get(provider, "openai")
return f"{prefix}/{model}"
def _extract_api_response(self, error: Exception) -> Optional[str]:
"""从异常中提取 API 服务器返回的原始响应信息"""
error_str = str(error)
        # Try to extract a JSON-formatted error body
import re
import json
        # Match payloads shaped like {'error': {...}} or {"error": {...}}
json_pattern = r"\{['\"]error['\"]:\s*\{[^}]+\}\}"
match = re.search(json_pattern, error_str)
if match:
try:
                # Replace single quotes with double quotes so the fragment parses as JSON
                # (best effort; it fails if the message itself contains quote characters)
                json_str = match.group().replace("'", '"')
error_obj = json.loads(json_str)
if 'error' in error_obj:
err = error_obj['error']
code = err.get('code', '')
message = err.get('message', '')
return f"[{code}] {message}" if code else message
            except (json.JSONDecodeError, AttributeError):
                # Malformed JSON or unexpected shape; fall back to the regex extraction below
                pass
        # Try to extract a bare "message" field
message_pattern = r"['\"]message['\"]:\s*['\"]([^'\"]+)['\"]"
match = re.search(message_pattern, error_str)
if match:
return match.group(1)
        # Try to pull the original message off the litellm exception
if hasattr(error, 'message'):
return error.message
if hasattr(error, 'llm_provider'):
            # litellm exceptions usually embed the original provider error after " - "
return error_str.split(' - ')[-1] if ' - ' in error_str else None
return None
def _get_api_base(self) -> Optional[str]:
"""获取 API 基础 URL"""
# 优先使用用户配置的 base_url
if self.config.base_url:
return self.config.base_url
        # For providers that need a custom base_url, use the default
if self.config.provider in self.CUSTOM_BASE_URL_PROVIDERS:
return DEFAULT_BASE_URLS.get(self.config.provider)
        # Ollama talks to a local server
if self.config.provider == LLMProvider.OLLAMA:
return DEFAULT_BASE_URLS.get(LLMProvider.OLLAMA, "http://localhost:11434")
return None
async def complete(self, request: LLMRequest) -> LLMResponse:
"""使用 LiteLLM 发送请求"""
try:
await self.validate_config()
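            # retry() comes from BaseLLMAdapter and is expected to apply the
            # exponential-backoff policy mentioned in the module docstring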
return await self.retry(lambda: self._send_request(request))
except Exception as error:
            self.handle_error(error, f"LiteLLM ({self.config.provider.value}) API call failed")
async def _send_request(self, request: LLMRequest) -> LLMResponse:
"""发送请求到 LiteLLM"""
import litellm
        # LiteLLM debug mode gives more detailed error information;
        # uncomment the next line to enable it.
        # litellm._turn_on_debug()
        # Disable LiteLLM's cache so every call really hits the API
        litellm.cache = None
        # Drop the reasoning_effort parameter LiteLLM may add automatically;
        # this keeps the model name from being mis-parsed as an effort value.
        litellm.drop_params = True
        # Build the message list
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        # 🔥 Prompt caching: add cache markers for LLMs that support it
cache_enabled = False
if self.config.provider == LLMProvider.CLAUDE:
            # Estimate the token count of the system prompt
system_tokens = 0
for msg in messages:
if msg.get("role") == "system":
system_tokens += estimate_tokens(msg.get("content", ""))
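            # prompt_cache_manager decides, based on the model/provider and the estimated
            # system-prompt size, whether cache markers should be added to the messages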
messages, cache_enabled = prompt_cache_manager.process_messages(
messages=messages,
model=self.config.model,
provider=self.config.provider.value,
system_prompt_tokens=system_tokens,
)
if cache_enabled:
logger.debug(f"🔥 Prompt Caching enabled for {self.config.model}")
        # Build the request parameters
kwargs: Dict[str, Any] = {
"model": self._litellm_model,
"messages": messages,
"temperature": request.temperature if request.temperature is not None else self.config.temperature,
"max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
"top_p": request.top_p if request.top_p is not None else self.config.top_p,
}
        # Set the API key
if self.config.api_key and self.config.api_key != "ollama":
kwargs["api_key"] = self.config.api_key
        # Set the API base URL
if self._api_base:
kwargs["api_base"] = self._api_base
logger.debug(f"🔗 使用自定义 API Base: {self._api_base}")
        # Set the request timeout
kwargs["timeout"] = self.config.timeout
        # Add extra parameters for the OpenAI provider
if self.config.provider == LLMProvider.OPENAI:
kwargs["frequency_penalty"] = self.config.frequency_penalty
kwargs["presence_penalty"] = self.config.presence_penalty
try:
            # Call LiteLLM
response = await litellm.acompletion(**kwargs)
except litellm.exceptions.AuthenticationError as e:
api_response = self._extract_api_response(e)
raise LLMError(f"API Key 无效或已过期", self.config.provider, 401, api_response=api_response)
except litellm.exceptions.RateLimitError as e:
error_msg = str(e)
api_response = self._extract_api_response(e)
            # Distinguish "insufficient balance" from "rate limited"
            if any(keyword in error_msg for keyword in ["余额不足", "资源包", "充值", "quota", "insufficient", "balance"]):
                raise LLMError("Account balance or quota exhausted; please top up and retry", self.config.provider, 402, api_response=api_response)
            raise LLMError("API rate limit exceeded; please retry later", self.config.provider, 429, api_response=api_response)
except litellm.exceptions.APIConnectionError as e:
api_response = self._extract_api_response(e)
raise LLMError(f"无法连接到 API 服务", self.config.provider, api_response=api_response)
except (litellm.exceptions.ServiceUnavailableError, litellm.exceptions.InternalServerError) as e:
api_response = self._extract_api_response(e)
raise LLMError(f"API 服务暂时不可用 ({type(e).__name__})", self.config.provider, 503, api_response=api_response)
except litellm.exceptions.APIError as e:
api_response = self._extract_api_response(e)
raise LLMError(f"API 错误", self.config.provider, getattr(e, 'status_code', None), api_response=api_response)
except Exception as e:
            # Catch anything else, classify it where possible, then re-raise
error_msg = str(e)
api_response = self._extract_api_response(e)
if "invalid_api_key" in error_msg.lower() or "incorrect api key" in error_msg.lower():
raise LLMError(f"API Key 无效", self.config.provider, 401, api_response=api_response)
elif "authentication" in error_msg.lower():
raise LLMError(f"认证失败", self.config.provider, 401, api_response=api_response)
elif any(keyword in error_msg for keyword in ["余额不足", "资源包", "充值", "quota", "insufficient", "balance"]):
raise LLMError(f"账户余额不足或配额已用尽", self.config.provider, 402, api_response=api_response)
raise
        # Parse the response
if not response:
raise LLMError("API 返回空响应", self.config.provider)
choice = response.choices[0] if response.choices else None
if not choice:
raise LLMError("API响应格式异常: 缺少choices字段", self.config.provider)
usage = None
if hasattr(response, "usage") and response.usage:
usage = LLMUsage(
prompt_tokens=response.usage.prompt_tokens or 0,
completion_tokens=response.usage.completion_tokens or 0,
total_tokens=response.usage.total_tokens or 0,
)
            # 🔥 Update prompt-cache statistics
if cache_enabled and hasattr(response.usage, "cache_creation_input_tokens"):
prompt_cache_manager.update_stats(
cache_creation_input_tokens=getattr(response.usage, "cache_creation_input_tokens", 0),
cache_read_input_tokens=getattr(response.usage, "cache_read_input_tokens", 0),
total_input_tokens=response.usage.prompt_tokens or 0,
)
return LLMResponse(
content=choice.message.content or "",
model=response.model,
usage=usage,
finish_reason=choice.finish_reason,
)
async def stream_complete(self, request: LLMRequest):
"""
流式调用 LLM逐 token 返回
Yields:
dict: {"type": "token", "content": str} 或 {"type": "done", "content": str, "usage": dict}
"""
import litellm
await self.validate_config()
litellm.cache = None
litellm.drop_params = True
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        # 🔥 Estimate the input token count (used when real usage data is unavailable)
input_tokens_estimate = sum(estimate_tokens(msg["content"]) for msg in messages)
kwargs = {
"model": self._litellm_model,
"messages": messages,
"temperature": request.temperature if request.temperature is not None else self.config.temperature,
"max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
"top_p": request.top_p if request.top_p is not None else self.config.top_p,
"stream": True, # 启用流式输出
}
        # 🔥 For models that support it, ask for usage info to be included in the stream
        # (the OpenAI API supports stream_options)
if self.config.provider in [LLMProvider.OPENAI, LLMProvider.DEEPSEEK]:
kwargs["stream_options"] = {"include_usage": True}
if self.config.api_key and self.config.api_key != "ollama":
kwargs["api_key"] = self.config.api_key
if self._api_base:
kwargs["api_base"] = self._api_base
kwargs["timeout"] = self.config.timeout
accumulated_content = ""
        final_usage = None  # 🔥 holds the final usage info
        chunk_count = 0  # 🔥 tracks how many chunks were received
try:
response = await litellm.acompletion(**kwargs)
async for chunk in response:
chunk_count += 1
                # 🔥 Check for usage info (some APIs include it in the final chunk)
if hasattr(chunk, "usage") and chunk.usage:
final_usage = {
"prompt_tokens": chunk.usage.prompt_tokens or 0,
"completion_tokens": chunk.usage.completion_tokens or 0,
"total_tokens": chunk.usage.total_tokens or 0,
}
logger.debug(f"Got usage from chunk: {final_usage}")
if not chunk.choices:
                    # 🔥 Some models send chunks without choices (e.g. heartbeats)
continue
delta = chunk.choices[0].delta
content = getattr(delta, "content", "") or ""
finish_reason = chunk.choices[0].finish_reason
if content:
accumulated_content += content
yield {
"type": "token",
"content": content,
"accumulated": accumulated_content,
}
                # 🔥 ENHANCED: handle chunks that have neither content nor finish_reason;
                # some models (e.g. Zhipu GLM) occasionally return empty chunks
if finish_reason:
                    # Stream finished
                    # 🔥 If usage never arrived in a chunk, estimate it
if not final_usage:
output_tokens_estimate = estimate_tokens(accumulated_content)
final_usage = {
"prompt_tokens": input_tokens_estimate,
"completion_tokens": output_tokens_estimate,
"total_tokens": input_tokens_estimate + output_tokens_estimate,
}
logger.debug(f"Estimated usage: {final_usage}")
                    # 🔥 ENHANCED: warn if we got a finish_reason but no accumulated content
if not accumulated_content:
logger.warning(f"Stream completed with no content after {chunk_count} chunks, finish_reason={finish_reason}")
yield {
"type": "done",
"content": accumulated_content,
"usage": final_usage,
"finish_reason": finish_reason,
}
break
            # 🔥 ENHANCED: if the loop ends without a finish_reason, still emit a done event
if accumulated_content:
logger.warning(f"Stream ended without finish_reason, returning accumulated content ({len(accumulated_content)} chars)")
if not final_usage:
output_tokens_estimate = estimate_tokens(accumulated_content)
final_usage = {
"prompt_tokens": input_tokens_estimate,
"completion_tokens": output_tokens_estimate,
"total_tokens": input_tokens_estimate + output_tokens_estimate,
}
yield {
"type": "done",
"content": accumulated_content,
"usage": final_usage,
"finish_reason": "complete",
}
except litellm.exceptions.RateLimitError as e:
            # Rate-limit error - needs special handling
logger.error(f"Stream rate limit error: {e}")
error_msg = str(e)
            # Distinguish "insufficient balance" from "rate limited"
            if any(keyword in error_msg.lower() for keyword in ["余额不足", "资源包", "充值", "quota", "exceeded", "billing"]):
                error_type = "quota_exceeded"
                user_message = "API quota exhausted; check your account balance or upgrade your plan"
else:
error_type = "rate_limit"
                # Try to extract the suggested retry delay from the error message
import re
retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
retry_seconds = float(retry_match.group(1)) if retry_match else 60
user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
yield {
"type": "error",
"error_type": error_type,
"error": error_msg,
"user_message": user_message,
"accumulated": accumulated_content,
"usage": {
"prompt_tokens": input_tokens_estimate,
"completion_tokens": output_tokens_estimate,
"total_tokens": input_tokens_estimate + output_tokens_estimate,
} if accumulated_content else None,
}
except litellm.exceptions.AuthenticationError as e:
            # Authentication error - invalid API key
logger.error(f"Stream authentication error: {e}")
yield {
"type": "error",
"error_type": "authentication",
"error": str(e),
"user_message": "API Key 无效或已过期,请检查配置",
"accumulated": accumulated_content,
"usage": None,
}
except litellm.exceptions.APIConnectionError as e:
            # Connection error - network problem
logger.error(f"Stream connection error: {e}")
yield {
"type": "error",
"error_type": "connection",
"error": str(e),
"user_message": "无法连接到 API 服务,请检查网络连接",
"accumulated": accumulated_content,
"usage": None,
}
except (litellm.exceptions.ServiceUnavailableError, litellm.exceptions.InternalServerError) as e:
            # Service unavailable - server-side 5xx error
logger.error(f"Stream server error ({type(e).__name__}): {e}")
yield {
"type": "error",
"error_type": "server_error",
"error": str(e),
"user_message": f"API 服务暂时不可用 ({type(e).__name__})",
"accumulated": accumulated_content,
"usage": None,
}
except Exception as e:
            # Other errors - check whether this is a wrapped rate-limit error
error_msg = str(e)
logger.error(f"Stream error: {e}")
            # A rate-limit error may arrive wrapped in another exception (e.g. ServiceUnavailableError wrapping a RateLimitError)
is_rate_limit = any(keyword in error_msg.lower() for keyword in [
"ratelimiterror", "rate limit", "429", "resource_exhausted",
"quota exceeded", "too many requests"
])
if is_rate_limit:
                # Handle it as a rate-limit error
import re
                # Check whether the quota is exhausted
if any(keyword in error_msg.lower() for keyword in ["quota", "exceeded", "billing"]):
error_type = "quota_exceeded"
user_message = "API 配额已用尽,请检查账户余额或升级计划"
else:
error_type = "rate_limit"
retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
retry_seconds = float(retry_match.group(1)) if retry_match else 60
user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
else:
error_type = "unknown"
user_message = "LLM 调用发生错误,请重试"
output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
yield {
"type": "error",
"error_type": error_type,
"error": error_msg,
"user_message": user_message,
"accumulated": accumulated_content,
"usage": {
"prompt_tokens": input_tokens_estimate,
"completion_tokens": output_tokens_estimate,
"total_tokens": input_tokens_estimate + output_tokens_estimate,
} if accumulated_content else None,
}
async def validate_config(self) -> bool:
"""验证配置"""
# Ollama 不需要 API Key
if self.config.provider == LLMProvider.OLLAMA:
if not self.config.model:
raise LLMError("未指定 Ollama 模型", LLMProvider.OLLAMA)
return True
        # All other providers require an API key
if not self.config.api_key:
raise LLMError(
f"API Key未配置 ({self.config.provider.value})",
self.config.provider,
)
        # Reject obvious placeholder keys (e.g. values containing "sk-your-" or masked with "***")
if "sk-your-" in self.config.api_key or "***" in self.config.api_key:
raise LLMError(
f"无效的 API Key (使用了占位符): {self.config.api_key[:10]}...",
self.config.provider,
401
)
if not self.config.model:
raise LLMError(
f"未指定模型 ({self.config.provider.value})",
self.config.provider,
)
return True
@classmethod
def supports_provider(cls, provider: LLMProvider) -> bool:
"""检查是否支持指定的提供商"""
return provider in cls.PROVIDER_PREFIX_MAP