"""
|
||
LLM工厂类 - 统一创建和管理LLM适配器
|
||
|
||
使用 LiteLLM 作为主要适配器,支持大多数 LLM 提供商。
|
||
对于 API 格式特殊的提供商(百度、MiniMax、豆包),使用原生适配器。
|
||
"""
|
||
|
||
from typing import Dict, List
from .types import LLMConfig, LLMProvider, DEFAULT_MODELS
from .base_adapter import BaseLLMAdapter
from .adapters import (
    LiteLLMAdapter,
    BaiduAdapter,
    MinimaxAdapter,
    DoubaoAdapter,
)


# Providers that must use native adapters (non-standard API formats).
NATIVE_ONLY_PROVIDERS = {
    LLMProvider.BAIDU,
    LLMProvider.MINIMAX,
    LLMProvider.DOUBAO,
}


class LLMFactory:
    """Factory for creating and caching LLM adapter instances."""

    _adapters: Dict[str, BaseLLMAdapter] = {}

    @classmethod
    def create_adapter(cls, config: LLMConfig) -> BaseLLMAdapter:
        """Create an LLM adapter instance, reusing a cached one when available."""
        # Resolve the default model before building the cache key, so a config
        # that omits the model and one that names the default explicitly map
        # to the same cached adapter.
        if not config.model:
            config.model = DEFAULT_MODELS.get(config.provider, "gpt-4o-mini")

        cache_key = cls._get_cache_key(config)

        # Return the cached instance if one exists.
        if cache_key in cls._adapters:
            return cls._adapters[cache_key]

        # Create a new adapter instance and cache it.
        adapter = cls._instantiate_adapter(config)
        cls._adapters[cache_key] = adapter

        return adapter

    @classmethod
    def _instantiate_adapter(cls, config: LLMConfig) -> BaseLLMAdapter:
        """Instantiate an adapter based on the provider type."""
        # Providers that must use a native adapter (default model is already
        # resolved in create_adapter).
        if config.provider in NATIVE_ONLY_PROVIDERS:
            return cls._create_native_adapter(config)

        # All other providers go through LiteLLM.
        if LiteLLMAdapter.supports_provider(config.provider):
            return LiteLLMAdapter(config)

        # Unsupported provider.
        raise ValueError(f"Unsupported LLM provider: {config.provider}")

    @classmethod
    def _create_native_adapter(cls, config: LLMConfig) -> BaseLLMAdapter:
        """Create a native adapter (only for providers with non-standard API formats)."""
        native_adapter_map = {
            LLMProvider.BAIDU: BaiduAdapter,
            LLMProvider.MINIMAX: MinimaxAdapter,
            LLMProvider.DOUBAO: DoubaoAdapter,
        }

        adapter_class = native_adapter_map.get(config.provider)
        if not adapter_class:
            raise ValueError(f"Unsupported native adapter provider: {config.provider}")

        return adapter_class(config)

    @classmethod
    def _get_cache_key(cls, config: LLMConfig) -> str:
        """Generate the cache key for a config."""
        # Only a short API-key prefix goes into the key: enough to tell
        # different credentials apart without embedding the full secret.
        api_key_prefix = config.api_key[:8] if config.api_key else "no-key"
        return f"{config.provider.value}:{config.model}:{api_key_prefix}"

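    # Illustrative key shape (assumes LLMProvider.OPENAI.value == "openai"
    # and an API key starting with "sk-proj-"):
    #   "openai:gpt-4o-mini:sk-proj-"
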
    @classmethod
    def clear_cache(cls) -> None:
        """Clear the adapter cache."""
        cls._adapters.clear()

    @classmethod
    def get_supported_providers(cls) -> List[LLMProvider]:
        """Return the list of supported providers."""
        return list(LLMProvider)

    @classmethod
    def get_default_model(cls, provider: LLMProvider) -> str:
        """Return the default model for a provider."""
        return DEFAULT_MODELS.get(provider, "gpt-4o-mini")

    @classmethod
    def get_available_models(cls, provider: LLMProvider) -> List[str]:
        """Return the list of available models for a provider."""
        models = {
            LLMProvider.GEMINI: [
                "gemini-2.5-flash",
                "gemini-2.5-pro",
                "gemini-1.5-flash",
                "gemini-1.5-pro",
            ],
            LLMProvider.OPENAI: [
                "gpt-4o",
                "gpt-4o-mini",
                "gpt-4-turbo",
                "gpt-4",
                "gpt-3.5-turbo",
                "o1-preview",
                "o1-mini",
            ],
            LLMProvider.CLAUDE: [
                "claude-3-5-sonnet-20241022",
                "claude-3-opus-20240229",
                "claude-3-sonnet-20240229",
                "claude-3-haiku-20240307",
            ],
            LLMProvider.QWEN: [
                "qwen-turbo",
                "qwen-plus",
                "qwen-max",
                "qwen-max-longcontext",
            ],
            LLMProvider.DEEPSEEK: [
                "deepseek-chat",
                "deepseek-coder",
            ],
            LLMProvider.ZHIPU: [
                "glm-4-flash",
                "glm-4",
                "glm-4-plus",
                "glm-4-long",
            ],
            LLMProvider.MOONSHOT: [
                "moonshot-v1-8k",
                "moonshot-v1-32k",
                "moonshot-v1-128k",
            ],
            LLMProvider.BAIDU: [
                "ERNIE-3.5-8K",
                "ERNIE-4.0-8K",
                "ERNIE-Speed-8K",
            ],
            LLMProvider.MINIMAX: [
                "abab6.5-chat",
                "abab6.5s-chat",
                "abab5.5-chat",
            ],
            LLMProvider.DOUBAO: [
                "doubao-pro-32k",
                "doubao-pro-128k",
                "doubao-lite-32k",
            ],
            LLMProvider.OLLAMA: [
                "llama3",
                "llama3.1",
                "llama3.2",
                "codellama",
                "mistral",
                "deepseek-coder-v2",
                "qwen2.5-coder",
            ],
        }
        return models.get(provider, [])
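

# A minimal usage sketch. The LLMConfig fields shown (provider, api_key,
# model) are assumptions based on how this module reads them; adjust to the
# actual definition in .types. Run via `python -m <package>.factory` so the
# relative imports resolve.
if __name__ == "__main__":
    config = LLMConfig(
        provider=LLMProvider.OPENAI,
        api_key="sk-...",  # placeholder credential
        model="",          # empty -> factory falls back to the provider default
    )
    adapter = LLMFactory.create_adapter(config)

    # Repeated calls with an equivalent config reuse the cached instance.
    assert adapter is LLMFactory.create_adapter(config)

    print(LLMFactory.get_default_model(LLMProvider.OPENAI))
    print(LLMFactory.get_available_models(LLMProvider.OLLAMA))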