CodeReview/backend/app/services/agent/tools/base.py

157 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Agent 工具基类
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type
from dataclasses import dataclass, field
from pydantic import BaseModel
import logging
import time
# Module-level logger; AgentTool.execute writes its debug and error records here.
logger = logging.getLogger(__name__)
@dataclass
class ToolResult:
    """Outcome of a single tool invocation.

    Attributes:
        success: Whether the tool call succeeded.
        data: Payload produced by the tool; may be of any type.
        error: Error message, rendered by ``to_string`` when ``success`` is false.
        duration_ms: Elapsed time in milliseconds; defaults to 0.
        metadata: Free-form extra information attached to the result.
    """

    success: bool
    data: Any = None
    error: Optional[str] = None
    duration_ms: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields of the result as a plain dictionary."""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }

    def to_string(self, max_length: int = 5000) -> str:
        """Render the result as text (used for LLM output).

        Args:
            max_length: Maximum number of characters of the rendered payload
                to keep; longer output is cut and a truncation note appended.

        Returns:
            ``"Error: <error>"`` for a failed result. Otherwise the payload:
            strings are returned as-is, dicts and lists are rendered as
            indented JSON, and anything else goes through ``str()``.
        """
        if not self.success:
            return f"Error: {self.error}"

        if isinstance(self.data, str):
            result = self.data
        elif isinstance(self.data, (dict, list)):
            import json

            try:
                # default=str is deliberate: this text is only shown to the
                # LLM, so values JSON cannot encode (sets, datetimes, paths,
                # ...) are stringified instead of aborting the conversion.
                result = json.dumps(
                    self.data, ensure_ascii=False, indent=2, default=str
                )
            except (TypeError, ValueError):
                # Unsupported key types or circular references still make
                # json.dumps fail; fall back to the plain representation.
                result = str(self.data)
        else:
            result = str(self.data)

        if len(result) > max_length:
            # len(result) is evaluated before the reassignment, so the note
            # reports the length of the untruncated text.
            result = result[:max_length] + f"\n... (truncated, total {len(result)} chars)"
        return result
class AgentTool(ABC):
    """Base class for agent tools.

    Every tool inherits from this class and implements ``name``,
    ``description`` and ``_execute``. Callers use ``execute``, which adds
    timing, logging and per-instance call statistics around ``_execute``.
    """

    # Class-level defaults keep ``execute`` and ``stats`` working even when a
    # subclass overrides ``__init__`` without calling ``super().__init__()``.
    _call_count: int = 0
    _total_duration_ms: int = 0

    def __init__(self):
        self._call_count = 0
        self._total_duration_ms = 0

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description (helps the agent understand what the tool does)."""
        pass

    @property
    def args_schema(self) -> Optional[Type[BaseModel]]:
        """Argument schema (a Pydantic model), or ``None`` when there is none."""
        return None

    @abstractmethod
    async def _execute(self, **kwargs) -> ToolResult:
        """Run the tool (implemented by subclasses)."""
        pass

    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool with timing, logging and statistics.

        Any exception raised by ``_execute`` is logged and converted into a
        failed ``ToolResult`` whose ``error`` is ``str(exception)``.

        Returns:
            The ``ToolResult`` from ``_execute`` (or the failure result), with
            ``duration_ms`` set to the measured elapsed time.
        """
        # perf_counter is monotonic, so the measured duration cannot jump or
        # go negative when the system clock is adjusted.
        start_time = time.perf_counter()
        try:
            logger.debug(f"Tool '{self.name}' executing with args: {kwargs}")
            result = await self._execute(**kwargs)
        except Exception as e:
            logger.error(f"Tool '{self.name}' error: {e}", exc_info=True)
            result = ToolResult(
                success=False,
                error=str(e),
            )

        duration_ms = int((time.perf_counter() - start_time) * 1000)
        result.duration_ms = duration_ms
        self._call_count += 1
        self._total_duration_ms += duration_ms
        logger.debug(f"Tool '{self.name}' completed in {duration_ms}ms, success={result.success}")
        return result

    def get_langchain_tool(self):
        """Convert this tool into a LangChain tool.

        Returns:
            A ``StructuredTool`` when ``args_schema`` is set; otherwise a
            single-input ``Tool`` whose argument is forwarded to ``execute``
            as ``query``.
        """
        from langchain.tools import Tool, StructuredTool
        import asyncio

        def sync_wrapper(**kwargs):
            """Synchronous wrapper: run ``execute`` to completion, return text."""
            # get_running_loop() only reports a loop running in *this* thread
            # and never creates one. The former get_event_loop() call raises
            # RuntimeError in worker threads that have no loop and is
            # deprecated outside a running loop.
            try:
                asyncio.get_running_loop()
                loop_running = True
            except RuntimeError:
                loop_running = False

            if loop_running:
                # asyncio.run() cannot be called while a loop is running in
                # the current thread, so drive the coroutine on a fresh loop
                # in a helper thread and block until it finishes.
                import concurrent.futures

                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                    future = executor.submit(asyncio.run, self.execute(**kwargs))
                    result = future.result()
            else:
                result = asyncio.run(self.execute(**kwargs))
            return result.to_string()

        async def async_wrapper(**kwargs):
            """Asynchronous wrapper: await ``execute`` and return text."""
            result = await self.execute(**kwargs)
            return result.to_string()

        if self.args_schema:
            return StructuredTool(
                name=self.name,
                description=self.description,
                func=sync_wrapper,
                coroutine=async_wrapper,
                args_schema=self.args_schema,
            )
        else:
            # Without a schema the tool takes one positional input, which is
            # passed on under the keyword ``query``.
            return Tool(
                name=self.name,
                description=self.description,
                func=lambda x: sync_wrapper(query=x),
                coroutine=lambda x: async_wrapper(query=x),
            )

    @property
    def stats(self) -> Dict[str, Any]:
        """Usage statistics: name, call count, total and average duration (ms)."""
        return {
            "name": self.name,
            "call_count": self._call_count,
            "total_duration_ms": self._total_duration_ms,
            # max(1, ...) avoids division by zero before the first call.
            "avg_duration_ms": self._total_duration_ms // max(1, self._call_count),
        }