"""Baidu ERNIE Bot (文心一言) adapter."""

import httpx
from typing import Optional

from ..base_adapter import BaseLLMAdapter
from ..types import LLMConfig, LLMRequest, LLMResponse, LLMError, LLMProvider, LLMUsage


class BaiduAdapter(BaseLLMAdapter):
    """Adapter for the Baidu ERNIE Bot (文心一言) API."""

    # Map model names to the endpoint path segment of Baidu's chat API.
    MODEL_ENDPOINTS = {
        "ERNIE-4.0": "completions_pro",
        "ERNIE-3.5-8K": "completions",
        "ERNIE-3.5-128K": "ernie-3.5-128k",
        "ERNIE-Speed": "ernie_speed",
        "ERNIE-Lite": "ernie-lite-8k",
    }
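    # These path segments correspond to routes documented for Baidu's Qianfan
    # "wenxinworkshop" chat API; verify them against current Baidu docs before
    # relying on the mapping.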

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self._access_token: Optional[str] = None
        self._base_url = config.base_url or "https://aip.baidubce.com"

    async def _get_access_token(self) -> str:
        """Fetch an access_token for the Baidu API.

        Note: Baidu issues access tokens in exchange for an API Key and a
        Secret Key. This adapter expects `config.api_key` in the form
        "api_key:secret_key".
        """
        if self._access_token:
            return self._access_token
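
        # Note: Baidu access tokens expire (the token response carries an
        # `expires_in` field, typically around 30 days per Baidu's docs), so
        # a production implementation should refresh rather than cache the
        # token indefinitely as done here.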

        # Split the combined credential into API Key and Secret Key.
        if ":" not in self.config.api_key:
            raise LLMError(
                "The Baidu API requires both an API Key and a Secret Key, "
                'formatted as "api_key:secret_key"',
                provider=LLMProvider.BAIDU
            )

        api_key, secret_key = self.config.api_key.split(":", 1)

        url = f"{self._base_url}/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": api_key,
            "client_secret": secret_key,
        }
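
        # Exchange the key pair for a token via Baidu's OAuth 2.0
        # client_credentials grant.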
        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(url, params=params)
            if response.status_code != 200:
                raise LLMError(
                    f"Failed to obtain a Baidu access_token: {response.text}",
                    provider=LLMProvider.BAIDU,
                    status_code=response.status_code
                )

            data = response.json()
            self._access_token = data.get("access_token")
            if not self._access_token:
                raise LLMError(
                    f"Baidu API returned an empty access_token: {response.text}",
                    provider=LLMProvider.BAIDU
                )

            return self._access_token

    async def complete(self, request: LLMRequest) -> LLMResponse:
        """Run the actual API call."""
        try:
            await self.validate_config()
            return await self.retry(lambda: self._send_request(request))
        except Exception as error:
            self.handle_error(error, "Baidu ERNIE Bot API call failed")
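            # handle_error is assumed to re-raise (wrapping the original
            # exception), so this method either returns an LLMResponse or
            # raises; see BaseLLMAdapter.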

    async def _send_request(self, request: LLMRequest) -> LLMResponse:
        """Send a chat request to the Baidu API."""
        access_token = await self._get_access_token()

        # Resolve the endpoint for the configured model; unmapped models fall
        # back to the generic "completions" route.
        model = self.config.model or "ERNIE-3.5-8K"
        endpoint = self.MODEL_ENDPOINTS.get(model, "completions")

        url = f"{self._base_url}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{endpoint}?access_token={access_token}"

        messages = [{"role": m.role, "content": m.content} for m in request.messages]
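        # Caveat: Baidu's chat API expects user/assistant messages to strictly
        # alternate and takes the system prompt as a separate top-level
        # "system" field; this straight pass-through enforces neither.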

        payload = {
            "messages": messages,
            "temperature": request.temperature if request.temperature is not None else self.config.temperature,
            "top_p": request.top_p if request.top_p is not None else self.config.top_p,
        }
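
        # Baidu names this parameter max_output_tokens rather than max_tokens.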
        if request.max_tokens or self.config.max_tokens:
            payload["max_output_tokens"] = request.max_tokens or self.config.max_tokens

        response = await self.client.post(
            url,
            headers=self.build_headers(),
            json=payload
        )

        if response.status_code != 200:
            try:
                error_data = response.json() if response.text else {}
            except ValueError:
                error_data = {}
            error_msg = error_data.get("error_msg", f"HTTP {response.status_code}")
            raise LLMError(error_msg, provider=LLMProvider.BAIDU, status_code=response.status_code)

        data = response.json()

        # Baidu reports API-level failures with HTTP 200 and an error_code in
        # the response body, so the body must be checked as well.
        if "error_code" in data:
            raise LLMError(
                f"Baidu API error: {data.get('error_msg', 'unknown error')}",
                provider=LLMProvider.BAIDU
            )

        usage = None
        if "usage" in data:
            usage = LLMUsage(
                prompt_tokens=data["usage"].get("prompt_tokens", 0),
                completion_tokens=data["usage"].get("completion_tokens", 0),
                total_tokens=data["usage"].get("total_tokens", 0)
            )

        return LLMResponse(
            content=data.get("result", ""),
            model=model,
            usage=usage,
            finish_reason=data.get("finish_reason")
        )

    async def validate_config(self) -> bool:
        """Check that the configuration is usable."""
        if not self.config.api_key:
            raise LLMError(
                "API Key is not configured",
                provider=LLMProvider.BAIDU
            )
        if ":" not in self.config.api_key:
            raise LLMError(
                "The Baidu API requires both an API Key and a Secret Key, "
                'formatted as "api_key:secret_key"',
                provider=LLMProvider.BAIDU
            )
        return True
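

# A minimal usage sketch, for illustration only. It assumes an LLMMessage
# type with `role`/`content` fields and LLMConfig/LLMRequest constructors
# matching the fields referenced above; adjust to the actual definitions
# in ..types.
#
#     import asyncio
#
#     config = LLMConfig(api_key="API_KEY:SECRET_KEY", model="ERNIE-3.5-8K")
#     adapter = BaiduAdapter(config)
#     request = LLMRequest(messages=[LLMMessage(role="user", content="Hello")])
#     print(asyncio.run(adapter.complete(request)).content)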