"""
Agent base class.

Defines the basic interface and shared functionality for agents.
"""

import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Any, Optional, AsyncGenerator

logger = logging.getLogger(__name__)


class AgentType(Enum):
    """Agent type; each member's value is its lowercase string identifier."""

    ORCHESTRATOR = "orchestrator"
    RECON = "recon"
    ANALYSIS = "analysis"
    VERIFICATION = "verification"


class AgentPattern(Enum):
    """Agent execution pattern."""

    # Reactive: a think / act / observe loop.
    REACT = "react"
    # Plan-and-execute: plan first, then execute.
    PLAN_AND_EXECUTE = "plan_execute"


@dataclass
class AgentConfig:
    """
    Agent configuration.

    Only ``name`` and ``agent_type`` are required; every other field has a
    default. The mutable fields (``tools``, ``metadata``) use
    ``default_factory`` so each instance gets its own container.
    """

    name: str
    agent_type: AgentType
    pattern: AgentPattern = AgentPattern.REACT

    # LLM settings
    model: Optional[str] = None
    temperature: float = 0.1
    max_tokens: int = 4096

    # Execution limits
    max_iterations: int = 20
    timeout_seconds: int = 600

    # Tool configuration (list of tool names)
    tools: List[str] = field(default_factory=list)

    # System prompt
    system_prompt: Optional[str] = None

    # Free-form metadata
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AgentResult:
    """
    Result of an agent run.

    Only ``success`` is required; all other fields default to empty/zero
    values. The mutable fields use ``default_factory`` so instances never
    share a list or dict.
    """

    success: bool
    data: Any = None
    error: Optional[str] = None

    # Execution statistics
    iterations: int = 0
    tool_calls: int = 0
    tokens_used: int = 0
    duration_ms: int = 0

    # Intermediate results
    intermediate_steps: List[Dict[str, Any]] = field(default_factory=list)

    # Free-form metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the result as a plain dictionary.

        The values are the field objects themselves (no copying).
        ``intermediate_steps`` is not part of the returned dictionary.
        """
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "iterations": self.iterations,
            "tool_calls": self.tool_calls,
            "tokens_used": self.tokens_used,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }


class BaseAgent(ABC):
    """
    Abstract base class for agents.

    Stores the agent's configuration, LLM service, tool registry and optional
    event emitter, keeps simple counters (LLM calls, tool calls, tokens), and
    provides helpers for emitting events, calling tools and calling the LLM.
    Every agent must subclass this and implement :meth:`run`.
    """

    def __init__(
        self,
        config: "AgentConfig",
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        """
        Initialize the agent.

        Args:
            config: Agent configuration.
            llm_service: LLM service; ``call_llm`` awaits its
                ``chat_completion(messages=..., temperature=...,
                max_tokens=..., tools=...)`` method.
            tools: Mapping of tool name to tool object.
            event_emitter: Optional event emitter; ``emit_event`` awaits its
                ``emit(event)`` method. When falsy, no events are emitted.
        """
        self.config = config
        self.llm_service = llm_service
        self.tools = tools
        self.event_emitter = event_emitter

        # Run state
        self._iteration = 0       # number of call_llm invocations
        self._total_tokens = 0    # sum of usage.total_tokens over LLM responses
        self._tool_calls = 0      # number of tools actually invoked
        self._cancelled = False   # set by cancel()

    @property
    def name(self) -> str:
        """Agent name, taken from the configuration."""
        return self.config.name

    @property
    def agent_type(self) -> "AgentType":
        """Agent type, taken from the configuration."""
        return self.config.agent_type

    @abstractmethod
    async def run(self, input_data: Dict[str, Any]) -> "AgentResult":
        """
        Execute the agent's task.

        Args:
            input_data: Input data.

        Returns:
            The agent's execution result.
        """

    def cancel(self):
        """Request cancellation by setting the cancelled flag."""
        # This only flips the flag; nothing in this class interrupts work in
        # progress, so subclasses have to poll `is_cancelled` themselves.
        self._cancelled = True

    @property
    def is_cancelled(self) -> bool:
        """Whether :meth:`cancel` has been called."""
        return self._cancelled

    async def emit_event(
        self,
        event_type: str,
        message: str,
        **kwargs
    ):
        """
        Emit an event through the configured emitter.

        Does nothing when no event emitter is configured.

        Args:
            event_type: Event type identifier.
            message: Human-readable message.
            **kwargs: Extra fields forwarded to ``AgentEventData``.
        """
        if not self.event_emitter:
            return

        # Imported only when an event is actually emitted
        # (presumably to avoid a circular import — TODO confirm).
        from ..event_manager import AgentEventData

        event = AgentEventData(
            event_type=event_type,
            message=message,
            **kwargs
        )
        await self.event_emitter.emit(event)

    async def emit_thinking(self, message: str):
        """Emit a "thinking" event whose message is prefixed with the agent name."""
        await self.emit_event("thinking", f"[{self.name}] {message}")

    async def emit_tool_call(self, tool_name: str, tool_input: Dict):
        """Emit a "tool_call" event carrying the tool name and its input."""
        await self.emit_event(
            "tool_call",
            f"[{self.name}] 调用工具: {tool_name}",
            tool_name=tool_name,
            tool_input=tool_input,
        )

    async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int):
        """
        Emit a "tool_result" event carrying the tool name and its duration.

        Args:
            tool_name: Name of the tool that finished.
            result: Result text supplied by the caller. It is accepted but not
                included in the emitted event.
            duration_ms: Tool execution time in milliseconds.
        """
        # NOTE(review): `result` is unused here — confirm whether the event
        # type has a field it should be forwarded to.
        await self.emit_event(
            "tool_result",
            f"[{self.name}] {tool_name} 完成 ({duration_ms}ms)",
            tool_name=tool_name,
            tool_duration_ms=duration_ms,
        )

    async def call_tool(self, tool_name: str, **kwargs) -> Any:
        """
        Call a tool by name.

        Emits a tool-call event before and a tool-result event after the
        tool's ``execute(**kwargs)`` coroutine. Exceptions raised by the tool
        propagate to the caller.

        Args:
            tool_name: Tool name.
            **kwargs: Tool arguments.

        Returns:
            Whatever the tool's ``execute`` returns, or ``None`` when no tool
            is registered under ``tool_name``.
        """
        tool = self.tools.get(tool_name)
        # Compare against None so a registered tool object that happens to be
        # falsy is still invoked.
        if tool is None:
            logger.warning("Tool not found: %s", tool_name)
            return None

        self._tool_calls += 1
        await self.emit_tool_call(tool_name, kwargs)

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system clock is adjusted.
        started = time.perf_counter()
        result = await tool.execute(**kwargs)
        duration_ms = int((time.perf_counter() - started) * 1000)

        # Tolerate results without a `.data` attribute (e.g. None) instead of
        # raising after the tool has already run successfully.
        preview = str(getattr(result, "data", None))[:200]
        await self.emit_tool_result(tool_name, preview, duration_ms)

        return result

    async def call_llm(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
    ) -> Dict[str, Any]:
        """
        Call the LLM.

        Increments the iteration counter (even if the call fails) and adds
        the response's ``usage.total_tokens`` to the token counter.

        Args:
            messages: Message list.
            tools: Descriptions of the available tools.

        Returns:
            The LLM response.

        Raises:
            Exception: Any error from the LLM service is logged and re-raised.
        """
        self._iteration += 1

        # Author's note: the real LLM service should be invoked here, either
        # through LangChain or by calling the API directly.
        # NOTE(review): `config.model` is not forwarded to chat_completion —
        # confirm the service selects the model on its own.
        try:
            response = await self.llm_service.chat_completion(
                messages=messages,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                tools=tools,
            )

            usage = response.get("usage")
            if usage:
                # `or 0` guards against an explicit total_tokens=None.
                self._total_tokens += usage.get("total_tokens", 0) or 0

            return response

        except Exception as e:
            logger.error("LLM call failed: %s", e)
            raise

    def get_tool_descriptions(self) -> List[Dict[str, Any]]:
        """
        Build tool descriptions for the LLM.

        Tools whose name starts with ``_`` are skipped. Each entry has the
        form ``{"type": "function", "function": {...}}``; a ``parameters``
        schema is added when the tool has a truthy ``args_schema``.
        """
        descriptions = []

        for name, tool in self.tools.items():
            if name.startswith("_"):
                continue

            function_spec = {
                "name": name,
                "description": tool.description,
            }

            # Add the parameter schema.
            # Assumes args_schema is a pydantic model class — TODO confirm.
            # Prefer pydantic v2's model_json_schema() and fall back to the
            # v1-style schema() when it is not available.
            args_schema = getattr(tool, "args_schema", None)
            if args_schema:
                build_schema = (
                    getattr(args_schema, "model_json_schema", None)
                    or args_schema.schema
                )
                function_spec["parameters"] = build_schema()

            descriptions.append({"type": "function", "function": function_spec})

        return descriptions

    def get_stats(self) -> Dict[str, Any]:
        """Return the agent's name, type value and run counters as a dict."""
        return {
            "agent": self.name,
            "type": self.agent_type.value,
            "iterations": self._iteration,
            "tool_calls": self._tool_calls,
            "tokens_used": self._total_tokens,
        }