159 lines
4.7 KiB
Python
159 lines
4.7 KiB
Python
"""
|
||
Agent 工具基类
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import Any, Dict, Optional, Type
|
||
from dataclasses import dataclass, field
|
||
from pydantic import BaseModel
|
||
import logging
|
||
import time
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class ToolResult:
|
||
"""工具执行结果"""
|
||
success: bool
|
||
data: Any = None
|
||
error: Optional[str] = None
|
||
duration_ms: int = 0
|
||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
return {
|
||
"success": self.success,
|
||
"data": self.data,
|
||
"error": self.error,
|
||
"duration_ms": self.duration_ms,
|
||
"metadata": self.metadata,
|
||
}
|
||
|
||
def to_string(self, max_length: int = 5000) -> str:
|
||
"""转换为字符串(用于 LLM 输出)"""
|
||
if not self.success:
|
||
return f"Error: {self.error}"
|
||
|
||
if isinstance(self.data, str):
|
||
result = self.data
|
||
elif isinstance(self.data, (dict, list)):
|
||
import json
|
||
result = json.dumps(self.data, ensure_ascii=False, indent=2)
|
||
else:
|
||
result = str(self.data)
|
||
|
||
if len(result) > max_length:
|
||
result = result[:max_length] + f"\n... (truncated, total {len(result)} chars)"
|
||
|
||
return result
|
||
|
||
|
||
class AgentTool(ABC):
|
||
"""
|
||
Agent 工具基类
|
||
所有工具需要继承此类并实现必要的方法
|
||
"""
|
||
|
||
def __init__(self):
|
||
self._call_count = 0
|
||
self._total_duration_ms = 0
|
||
|
||
@property
|
||
@abstractmethod
|
||
def name(self) -> str:
|
||
"""工具名称"""
|
||
pass
|
||
|
||
@property
|
||
@abstractmethod
|
||
def description(self) -> str:
|
||
"""工具描述(用于 Agent 理解工具功能)"""
|
||
pass
|
||
|
||
@property
|
||
def args_schema(self) -> Optional[Type[BaseModel]]:
|
||
"""参数 Schema(Pydantic 模型)"""
|
||
return None
|
||
|
||
@abstractmethod
|
||
async def _execute(self, **kwargs) -> ToolResult:
|
||
"""执行工具(子类实现)"""
|
||
pass
|
||
|
||
async def execute(self, **kwargs) -> ToolResult:
|
||
"""执行工具(带计时和日志)"""
|
||
start_time = time.time()
|
||
|
||
try:
|
||
logger.debug(f"Tool '{self.name}' executing with args: {kwargs}")
|
||
result = await self._execute(**kwargs)
|
||
|
||
except Exception as e:
|
||
logger.error(f"Tool '{self.name}' error: {e}", exc_info=True)
|
||
error_msg = str(e)
|
||
result = ToolResult(
|
||
success=False,
|
||
data=f"工具执行异常: {error_msg}", # 🔥 修复:设置 data 字段避免 None
|
||
error=error_msg,
|
||
)
|
||
|
||
duration_ms = int((time.time() - start_time) * 1000)
|
||
result.duration_ms = duration_ms
|
||
|
||
self._call_count += 1
|
||
self._total_duration_ms += duration_ms
|
||
|
||
logger.debug(f"Tool '{self.name}' completed in {duration_ms}ms, success={result.success}")
|
||
|
||
return result
|
||
|
||
def get_langchain_tool(self):
|
||
"""转换为 LangChain Tool"""
|
||
from langchain.tools import Tool, StructuredTool
|
||
import asyncio
|
||
|
||
def sync_wrapper(**kwargs):
|
||
"""同步包装器"""
|
||
loop = asyncio.get_event_loop()
|
||
if loop.is_running():
|
||
import concurrent.futures
|
||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||
future = executor.submit(asyncio.run, self.execute(**kwargs))
|
||
result = future.result()
|
||
else:
|
||
result = asyncio.run(self.execute(**kwargs))
|
||
return result.to_string()
|
||
|
||
async def async_wrapper(**kwargs):
|
||
"""异步包装器"""
|
||
result = await self.execute(**kwargs)
|
||
return result.to_string()
|
||
|
||
if self.args_schema:
|
||
return StructuredTool(
|
||
name=self.name,
|
||
description=self.description,
|
||
func=sync_wrapper,
|
||
coroutine=async_wrapper,
|
||
args_schema=self.args_schema,
|
||
)
|
||
else:
|
||
return Tool(
|
||
name=self.name,
|
||
description=self.description,
|
||
func=lambda x: sync_wrapper(query=x),
|
||
coroutine=lambda x: async_wrapper(query=x),
|
||
)
|
||
|
||
@property
|
||
def stats(self) -> Dict[str, Any]:
|
||
"""工具使用统计"""
|
||
return {
|
||
"name": self.name,
|
||
"call_count": self._call_count,
|
||
"total_duration_ms": self._total_duration_ms,
|
||
"avg_duration_ms": self._total_duration_ms // max(1, self._call_count),
|
||
}
|
||
|