""" Agent 工具基类 """ from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Type from dataclasses import dataclass, field from pydantic import BaseModel import logging import time logger = logging.getLogger(__name__) @dataclass class ToolResult: """工具执行结果""" success: bool data: Any = None error: Optional[str] = None duration_ms: int = 0 metadata: Dict[str, Any] = field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: return { "success": self.success, "data": self.data, "error": self.error, "duration_ms": self.duration_ms, "metadata": self.metadata, } def to_string(self, max_length: int = 5000) -> str: """转换为字符串(用于 LLM 输出)""" if not self.success: return f"Error: {self.error}" if isinstance(self.data, str): result = self.data elif isinstance(self.data, (dict, list)): import json result = json.dumps(self.data, ensure_ascii=False, indent=2) else: result = str(self.data) if len(result) > max_length: result = result[:max_length] + f"\n... (truncated, total {len(result)} chars)" return result class AgentTool(ABC): """ Agent 工具基类 所有工具需要继承此类并实现必要的方法 """ def __init__(self): self._call_count = 0 self._total_duration_ms = 0 @property @abstractmethod def name(self) -> str: """工具名称""" pass @property @abstractmethod def description(self) -> str: """工具描述(用于 Agent 理解工具功能)""" pass @property def args_schema(self) -> Optional[Type[BaseModel]]: """参数 Schema(Pydantic 模型)""" return None @abstractmethod async def _execute(self, **kwargs) -> ToolResult: """执行工具(子类实现)""" pass async def execute(self, **kwargs) -> ToolResult: """执行工具(带计时和日志)""" start_time = time.time() try: logger.debug(f"Tool '{self.name}' executing with args: {kwargs}") result = await self._execute(**kwargs) except Exception as e: logger.error(f"Tool '{self.name}' error: {e}", exc_info=True) result = ToolResult( success=False, error=str(e), ) duration_ms = int((time.time() - start_time) * 1000) result.duration_ms = duration_ms self._call_count += 1 self._total_duration_ms += duration_ms logger.debug(f"Tool '{self.name}' completed in {duration_ms}ms, success={result.success}") return result def get_langchain_tool(self): """转换为 LangChain Tool""" from langchain.tools import Tool, StructuredTool import asyncio def sync_wrapper(**kwargs): """同步包装器""" loop = asyncio.get_event_loop() if loop.is_running(): import concurrent.futures with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(asyncio.run, self.execute(**kwargs)) result = future.result() else: result = asyncio.run(self.execute(**kwargs)) return result.to_string() async def async_wrapper(**kwargs): """异步包装器""" result = await self.execute(**kwargs) return result.to_string() if self.args_schema: return StructuredTool( name=self.name, description=self.description, func=sync_wrapper, coroutine=async_wrapper, args_schema=self.args_schema, ) else: return Tool( name=self.name, description=self.description, func=lambda x: sync_wrapper(query=x), coroutine=lambda x: async_wrapper(query=x), ) @property def stats(self) -> Dict[str, Any]: """工具使用统计""" return { "name": self.name, "call_count": self._call_count, "total_duration_ms": self._total_duration_ms, "avg_duration_ms": self._total_duration_ms // max(1, self._call_count), }