CodeReview/backend/app/services/agent/streaming/tool_stream.py

320 lines
10 KiB
Python

"""
工具调用流式处理器
展示工具调用的输入、执行过程和输出
"""
import asyncio
import time
import logging
from typing import Any, Dict, Optional, AsyncGenerator, List, Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
# Module-level logger named after this module (__name__).
logger = logging.getLogger(__name__)
class ToolCallState(str, Enum):
    """Lifecycle state of a single tool call (str-valued for easy serialization)."""

    # Waiting to be executed.
    PENDING = "pending"
    # Currently executing.
    RUNNING = "running"
    # Finished without error.
    SUCCESS = "success"
    # Finished with an error.
    ERROR = "error"
    # Stopped after exceeding its time limit.
    TIMEOUT = "timeout"
@dataclass
class ToolCallEvent:
    """A record of one tool call: its inputs, outputs, timing and ordering metadata."""

    tool_name: str
    state: ToolCallState

    # Inputs and outputs
    input_params: Dict[str, Any] = field(default_factory=dict)
    output_data: Optional[Any] = None
    error_message: Optional[str] = None

    # Timing: start/end are float timestamps supplied by the caller;
    # duration_ms is the elapsed time in milliseconds.
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    duration_ms: int = 0

    # Metadata
    call_id: Optional[str] = None
    sequence: int = 0
    # Creation time of this object as an ISO-8601 string in UTC.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """
        Return a dict view of the event.

        ``input_params`` and ``output_data`` are shortened with ``_truncate``;
        ``start_time`` and ``end_time`` are not included.
        """
        return {
            "tool_name": self.tool_name,
            "state": self.state.value,
            "input_params": self._truncate(self.input_params),
            "output_data": self._truncate(self.output_data),
            "error_message": self.error_message,
            "duration_ms": self.duration_ms,
            "call_id": self.call_id,
            "sequence": self.sequence,
            "timestamp": self.timestamp,
        }

    def _truncate(
        self,
        data: Any,
        max_length: int = 500,
        _ancestors: frozenset = frozenset(),
    ) -> Any:
        """
        Recursively shorten ``data`` for display.

        - ``None`` is returned unchanged.
        - Strings longer than ``max_length`` are cut to ``max_length``
          characters followed by "...".
        - Dicts keep their first 20 items; each value gets half the budget.
        - Lists keep their first 20 items; the budget is split evenly across
          the kept items.
        - Anything else is converted with ``str()`` and cut like a string.

        Args:
            data: Value to shorten.
            max_length: Character budget for this value.
            _ancestors: Internal; ids of the dicts/lists currently being
                expanded. Used to detect self-referential containers.

        Returns:
            The shortened value. A dict/list that contains itself is rendered
            as "{...}" / "[...]" where the cycle closes, like Python's repr.
        """
        if data is None:
            return None
        if isinstance(data, str):
            return data[:max_length] + "..." if len(data) > max_length else data
        if isinstance(data, (dict, list)):
            if id(data) in _ancestors:
                # The budget bottoms out at 0 and never shrinks further, so
                # without this guard a cyclic container recursed until
                # RecursionError.
                return "{...}" if isinstance(data, dict) else "[...]"
            path = _ancestors | {id(data)}
            if isinstance(data, dict):
                child_budget = max_length // 2
                return {
                    k: self._truncate(v, child_budget, path)
                    for k, v in list(data.items())[:20]
                }
            max_items = min(20, len(data))
            if max_items == 0:
                return []
            child_budget = max_length // max_items
            return [self._truncate(item, child_budget, path) for item in data[:max_items]]
        text = str(data)
        return text[:max_length] + "..." if len(text) > max_length else text
class ToolStreamHandler:
    """
    Streaming handler for tool calls.

    Responsibilities:
    1. Track the state of every tool call (active vs. finished).
    2. Record each call's input parameters.
    3. Notify ``on_event`` at every state transition so execution can be
       streamed.
    4. Record each call's output and execution time.

    NOTE: the start and end notifications of one call pass the *same*
    ``ToolCallEvent`` object, which is mutated in place when the call
    finishes. A callback that keeps the event (e.g. queues it) should copy or
    serialize it (``event.to_dict()``) immediately.
    """

    def __init__(
        self,
        on_event: Optional[Callable[[ToolCallEvent], None]] = None,
    ):
        # Synchronous callback invoked with the event on start, end and timeout.
        self.on_event = on_event
        # Last sequence number handed out; shared by all calls of this handler.
        self._sequence = 0
        # call_id -> event for calls that have started but not yet finished.
        self._active_calls: Dict[str, ToolCallEvent] = {}
        # Finished calls (success, error or timeout) in completion order.
        self._history: List[ToolCallEvent] = []

    def _next_sequence(self) -> int:
        """Increment and return the handler-wide event sequence number."""
        self._sequence += 1
        return self._sequence

    def _generate_call_id(self) -> str:
        """Return a short random call ID: the first 8 characters of a UUID4."""
        import uuid

        return str(uuid.uuid4())[:8]

    def _notify(self, event: ToolCallEvent) -> None:
        """Forward ``event`` to the ``on_event`` callback, if one is set."""
        if self.on_event:
            self.on_event(event)

    def _finish(self, call_id: str) -> Optional[ToolCallEvent]:
        """
        Detach an active call and stamp its end time, duration and sequence.

        Returns ``None`` if ``call_id`` is not active. The caller sets the
        final state and then hands the event to ``_archive``.
        """
        event = self._active_calls.pop(call_id, None)
        if event is None:
            return None
        event.end_time = time.time()
        event.duration_ms = (
            int((event.end_time - event.start_time) * 1000) if event.start_time else 0
        )
        event.sequence = self._next_sequence()
        return event

    def _archive(self, event: ToolCallEvent) -> ToolCallEvent:
        """Append a finished event to the history, notify the callback, return it."""
        self._history.append(event)
        self._notify(event)
        return event

    @staticmethod
    def _error_text(exc: BaseException) -> str:
        """Message for ``exc``, falling back to its class name when the message is empty."""
        return str(exc) or type(exc).__name__

    @staticmethod
    async def _invoke(
        tool_func: Callable,
        args: tuple,
        kwargs: Dict[str, Any],
        timeout: Optional[int],
    ) -> Any:
        """Run ``tool_func(*args, **kwargs)``, enforcing ``timeout`` when it is truthy."""
        if asyncio.iscoroutinefunction(tool_func):
            if timeout:
                return await asyncio.wait_for(tool_func(*args, **kwargs), timeout=timeout)
            return await tool_func(*args, **kwargs)
        if timeout:
            # Run in a worker thread so wait_for can time it out. On timeout the
            # thread itself cannot be stopped and keeps running in the background.
            return await asyncio.wait_for(
                asyncio.to_thread(tool_func, *args, **kwargs),
                timeout=timeout,
            )
        # NOTE(review): without a timeout a synchronous tool runs directly on the
        # event-loop thread and blocks the loop until it returns; kept as-is to
        # preserve the existing threading behavior.
        return tool_func(*args, **kwargs)

    async def emit_tool_start(
        self,
        tool_name: str,
        input_params: Dict[str, Any],
        call_id: Optional[str] = None,
    ) -> ToolCallEvent:
        """
        Register a new tool call as RUNNING and emit its start event.

        Args:
            tool_name: Name of the tool being called.
            input_params: Parameters passed to the tool.
            call_id: Identifier for the call; a random one is generated when
                omitted. An active call with the same ID is replaced.

        Returns:
            The event for this call (also stored as the active call).
        """
        call_id = call_id or self._generate_call_id()
        event = ToolCallEvent(
            tool_name=tool_name,
            state=ToolCallState.RUNNING,
            input_params=input_params,
            start_time=time.time(),
            call_id=call_id,
            sequence=self._next_sequence(),
        )
        self._active_calls[call_id] = event
        self._notify(event)
        return event

    async def emit_tool_end(
        self,
        call_id: str,
        output_data: Any,
        is_error: bool = False,
        error_message: Optional[str] = None,
    ) -> Optional[ToolCallEvent]:
        """
        Finish an active call as SUCCESS or ERROR and emit its end event.

        Args:
            call_id: ID of the active call.
            output_data: Tool output (stored on the event in both cases).
            is_error: Mark the call as ERROR instead of SUCCESS.
            error_message: Error text. When ``is_error`` is set and this is
                empty, ``str(output_data)`` is used instead.

        Returns:
            The finished event, or ``None`` (with a warning logged) if
            ``call_id`` is not an active call.
        """
        event = self._finish(call_id)
        if event is None:
            logger.warning("Unknown tool call: %s", call_id)
            return None
        event.output_data = output_data
        if is_error:
            event.state = ToolCallState.ERROR
            event.error_message = error_message or str(output_data)
        else:
            event.state = ToolCallState.SUCCESS
        return self._archive(event)

    async def emit_tool_timeout(
        self, call_id: str, timeout_seconds: int
    ) -> Optional[ToolCallEvent]:
        """
        Finish an active call as TIMEOUT and emit its end event.

        Returns the finished event, or ``None`` (silently) if ``call_id`` is
        not an active call.
        """
        event = self._finish(call_id)
        if event is None:
            return None
        event.state = ToolCallState.TIMEOUT
        event.error_message = f"Tool execution timed out after {timeout_seconds}s"
        return self._archive(event)

    def wrap_tool(
        self,
        tool_func: Callable,
        tool_name: str,
        timeout: Optional[int] = None,
    ) -> Callable:
        """
        Wrap a tool function so every call is tracked and streamed.

        The returned coroutine function emits a start event and runs the tool.
        An async tool is awaited. A sync tool is called directly, or in a
        worker thread when ``timeout`` is set. The wrapper then emits the
        matching end/timeout event and returns the tool's result, or re-raises
        its exception.

        Args:
            tool_func: Sync or async function implementing the tool.
            tool_name: Name recorded on the events.
            timeout: Time limit in seconds; ``None`` or ``0`` means no limit.

        Returns:
            An ``async`` wrapper carrying ``tool_func``'s name, docstring and
            ``__wrapped__`` (via ``functools.wraps``).
        """
        import functools

        @functools.wraps(tool_func)
        async def wrapped(*args, **kwargs):
            call_id = self._generate_call_id()
            await self.emit_tool_start(
                tool_name=tool_name,
                input_params={"args": args, "kwargs": kwargs},
                call_id=call_id,
            )
            try:
                result = await self._invoke(tool_func, args, kwargs, timeout)
            except asyncio.TimeoutError as exc:
                if timeout:
                    await self.emit_tool_timeout(call_id, timeout)
                else:
                    # No limit was configured, so the tool raised this itself;
                    # report a plain error rather than "timed out after 0s".
                    await self.emit_tool_end(
                        call_id, None, is_error=True, error_message=self._error_text(exc)
                    )
                raise
            except asyncio.CancelledError:
                # CancelledError derives from BaseException (3.8+), so the generic
                # handler below misses it; without this the call stays active forever.
                await self.emit_tool_end(
                    call_id, None, is_error=True, error_message="Tool call cancelled"
                )
                raise
            except Exception as exc:
                await self.emit_tool_end(
                    call_id, None, is_error=True, error_message=self._error_text(exc)
                )
                raise
            # Outside the try: if the on_event callback fails here, it must not
            # trigger a second (error) end event for a call that already finished.
            await self.emit_tool_end(call_id, result)
            return result

        return wrapped

    def get_active_calls(self) -> List[ToolCallEvent]:
        """Return the events of calls that have started but not finished."""
        return list(self._active_calls.values())

    def get_history(self, limit: int = 100) -> List[ToolCallEvent]:
        """
        Return up to ``limit`` most recently finished events, oldest first.

        A non-positive ``limit`` returns an empty list. Previously ``limit=0``
        returned the entire history, because ``[-0:]`` slices everything.
        """
        if limit <= 0:
            return []
        return self._history[-limit:]

    def get_stats(self) -> Dict[str, Any]:
        """
        Summarize finished calls.

        Returns overall counts per final state, success rate, total and average
        duration (ms), the number of still-active calls, and a per-tool
        breakdown under ``"by_tool"``. In the breakdown, ``"errors"`` counts
        both ERROR and TIMEOUT.
        """
        total_calls = len(self._history)
        success_calls = error_calls = timeout_calls = 0
        total_duration = 0
        tool_stats: Dict[str, Dict[str, int]] = {}
        for event in self._history:
            stats = tool_stats.setdefault(
                event.tool_name,
                {"calls": 0, "success": 0, "errors": 0, "total_duration_ms": 0},
            )
            stats["calls"] += 1
            stats["total_duration_ms"] += event.duration_ms
            total_duration += event.duration_ms
            if event.state == ToolCallState.SUCCESS:
                success_calls += 1
                stats["success"] += 1
            elif event.state == ToolCallState.ERROR:
                error_calls += 1
                stats["errors"] += 1
            elif event.state == ToolCallState.TIMEOUT:
                timeout_calls += 1
                stats["errors"] += 1

        avg_duration = total_duration / total_calls if total_calls > 0 else 0
        return {
            "total_calls": total_calls,
            "success_calls": success_calls,
            "error_calls": error_calls,
            "timeout_calls": timeout_calls,
            "success_rate": success_calls / total_calls if total_calls > 0 else 0,
            "total_duration_ms": total_duration,
            "avg_duration_ms": round(avg_duration, 2),
            "active_calls": len(self._active_calls),
            "by_tool": tool_stats,
        }

    def clear(self):
        """
        Forget all active calls and history, and reset the sequence counter.

        Calls still running when this is invoked will hit the "Unknown tool
        call" path in ``emit_tool_end`` when they finish.
        """
        self._active_calls.clear()
        self._history.clear()
        self._sequence = 0