157 lines
4.6 KiB
Python
157 lines
4.6 KiB
Python
|
|
"""
Agent tool base classes (Agent 工具基类).
"""
|
|||
|
|
|
|||
|
|
from abc import ABC, abstractmethod
|
|||
|
|
from typing import Any, Dict, Optional, Type
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
import logging
|
|||
|
|
import time
|
|||
|
|
|
|||
|
|
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ToolResult:
    """Result of a single tool execution.

    Plain data container; ``to_dict`` exposes it as a dict and ``to_string``
    renders it as text for an LLM.
    """

    # Whether the tool call completed successfully.
    success: bool
    # Payload produced by the tool (any type).
    data: Any = None
    # Error message describing the failure when ``success`` is False.
    error: Optional[str] = None
    # Execution time of the call in milliseconds.
    duration_ms: int = 0
    # Free-form extra information attached to the result.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict (shallow: values are not copied)."""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }

    def to_string(self, max_length: int = 5000) -> str:
        """Render the result as text (used as LLM-facing output).

        Args:
            max_length: Maximum number of characters of rendered payload to
                keep. Longer text is cut and a truncation note is appended.
                Negative values are treated as 0.

        Returns:
            ``"Error: <error>"`` for a failed result. Otherwise the payload:
            strings verbatim, dicts/lists as indented JSON, anything else via
            ``str()``.
        """
        if not self.success:
            return f"Error: {self.error}"

        if isinstance(self.data, str):
            text = self.data
        elif isinstance(self.data, (dict, list)):
            import json

            try:
                # default=str: this text is display-only, so values json cannot
                # encode natively (datetime, Decimal, set, ...) are stringified
                # instead of aborting the whole rendering with TypeError.
                text = json.dumps(self.data, ensure_ascii=False, indent=2, default=str)
            except (TypeError, ValueError):
                # Non-str dict keys (TypeError) or circular references
                # (ValueError) cannot be JSON-encoded; fall back to str().
                text = str(self.data)
        else:
            text = str(self.data)

        # A negative slice bound would cut from the end; clamp it instead.
        limit = max(0, max_length)
        if len(text) > limit:
            # The right-hand side is evaluated before rebinding, so len(text)
            # reports the original (untruncated) length.
            text = text[:limit] + f"\n... (truncated, total {len(text)} chars)"
        return text
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AgentTool(ABC):
    """Base class for agent tools.

    Every tool must subclass this and implement ``name``, ``description`` and
    ``_execute``; ``args_schema`` may be overridden to expose a Pydantic
    argument model. Callers should invoke ``execute``, which adds timing,
    logging, error capture and per-instance usage counters.
    """

    def __init__(self):
        # Per-instance usage counters, reported by ``stats``.
        self._call_count = 0
        self._total_duration_ms = 0

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description (lets the agent understand what the tool does)."""

    @property
    def args_schema(self) -> Optional[Type[BaseModel]]:
        """Pydantic model describing the tool arguments, or None.

        When None, ``get_langchain_tool`` builds a single-input tool whose
        input is passed to ``execute`` as ``query``.
        """
        return None

    @abstractmethod
    async def _execute(self, **kwargs) -> ToolResult:
        """Run the tool (implemented by subclasses).

        Exceptions raised here are caught by ``execute`` and converted into a
        failed ``ToolResult``.
        """

    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool with timing, logging and error capture.

        Any ``Exception`` from ``_execute`` is logged and returned as
        ``ToolResult(success=False, error=str(exc))`` rather than raised. On
        every call the result's ``duration_ms`` is set and the instance's
        usage counters are updated.
        """
        # perf_counter is monotonic: durations are unaffected by wall-clock
        # adjustments (time.time() can jump or even go backwards).
        start = time.perf_counter()

        try:
            logger.debug("Tool '%s' executing with args: %s", self.name, kwargs)
            result = await self._execute(**kwargs)
        except Exception as exc:
            logger.exception("Tool '%s' error: %s", self.name, exc)
            result = ToolResult(success=False, error=str(exc))
        else:
            if not isinstance(result, ToolResult):
                # A subclass bug (e.g. a missing ``return``) would otherwise
                # surface as a confusing AttributeError when setting duration_ms.
                logger.error(
                    "Tool '%s' returned %s instead of ToolResult",
                    self.name,
                    type(result).__name__,
                )
                result = ToolResult(
                    success=False,
                    error=f"Tool '{self.name}' returned {type(result).__name__}, expected ToolResult",
                )

        duration_ms = int((time.perf_counter() - start) * 1000)
        result.duration_ms = duration_ms

        self._call_count += 1
        self._total_duration_ms += duration_ms

        logger.debug(
            "Tool '%s' completed in %dms, success=%s",
            self.name,
            duration_ms,
            result.success,
        )
        return result

    def get_langchain_tool(self):
        """Wrap this tool as a LangChain tool.

        Returns a ``StructuredTool`` when ``args_schema`` is set. Otherwise it
        returns a single-input ``Tool`` whose input is forwarded to ``execute``
        as ``query``. Both the sync and async entry points return
        ``ToolResult.to_string()``.
        """
        from langchain.tools import Tool, StructuredTool
        import asyncio
        import concurrent.futures

        def sync_wrapper(**kwargs):
            """Synchronous entry point: run ``execute`` to completion."""
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in this thread, so asyncio.run is safe.
                # (asyncio.get_event_loop() is deprecated here and raises in
                # non-main threads without a loop.)
                result = asyncio.run(self.execute(**kwargs))
            else:
                # Called synchronously from inside a running loop: asyncio.run
                # cannot nest, so run on a fresh loop in a worker thread. This
                # blocks the calling loop until the tool finishes.
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                    future = executor.submit(asyncio.run, self.execute(**kwargs))
                    result = future.result()
            return result.to_string()

        async def async_wrapper(**kwargs):
            """Asynchronous entry point: await ``execute`` directly."""
            result = await self.execute(**kwargs)
            return result.to_string()

        schema = self.args_schema
        if schema is not None:
            return StructuredTool(
                name=self.name,
                description=self.description,
                func=sync_wrapper,
                coroutine=async_wrapper,
                args_schema=schema,
            )
        # Single-input tool: the raw input is passed through as ``query``.
        return Tool(
            name=self.name,
            description=self.description,
            func=lambda x: sync_wrapper(query=x),
            coroutine=lambda x: async_wrapper(query=x),
        )

    @property
    def stats(self) -> Dict[str, Any]:
        """Usage statistics for this instance (durations in whole milliseconds)."""
        return {
            "name": self.name,
            "call_count": self._call_count,
            "total_duration_ms": self._total_duration_ms,
            # max(1, ...) avoids division by zero before the first call.
            "avg_duration_ms": self._total_duration_ms // max(1, self._call_count),
        }
|
|||
|
|
|