CodeReview/backend/app/services/agent/core/persistence.py

414 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Agent 状态持久化模块
提供 Agent 状态的持久化和恢复功能:
- Agent 状态序列化和反序列化
- 检查点保存和恢复
- 消息历史持久化
- 执行记录持久化
"""
import json
import logging
import os
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional
from pathlib import Path
from .state import AgentState, AgentStatus
from .registry import agent_registry
logger = logging.getLogger(__name__)
class AgentStatePersistence:
"""
Agent 状态持久化管理器
支持:
- 文件系统持久化
- 数据库持久化(可选)
- 检查点机制
"""
def __init__(
self,
persist_dir: str = "./agent_checkpoints",
use_database: bool = False,
db_session_factory=None,
):
"""
初始化持久化管理器
Args:
persist_dir: 持久化目录
use_database: 是否使用数据库持久化
db_session_factory: 数据库会话工厂
"""
self.persist_dir = Path(persist_dir)
self.use_database = use_database
self.db_session_factory = db_session_factory
# 确保目录存在
self.persist_dir.mkdir(parents=True, exist_ok=True)
# ============ 文件系统持久化 ============
def save_state(self, state: AgentState, checkpoint_name: Optional[str] = None) -> str:
"""
保存 Agent 状态到文件
Args:
state: Agent 状态
checkpoint_name: 检查点名称(可选)
Returns:
保存的文件路径
"""
# 生成文件名
if checkpoint_name:
filename = f"{state.agent_id}_{checkpoint_name}.json"
else:
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
filename = f"{state.agent_id}_{timestamp}.json"
filepath = self.persist_dir / filename
# 序列化状态
state_dict = self._serialize_state(state)
# 保存到文件
with open(filepath, "w", encoding="utf-8") as f:
json.dump(state_dict, f, ensure_ascii=False, indent=2)
logger.info(f"Saved agent state to {filepath}")
return str(filepath)
def load_state(self, filepath: str) -> Optional[AgentState]:
"""
从文件加载 Agent 状态
Args:
filepath: 文件路径
Returns:
Agent 状态,如果加载失败返回 None
"""
try:
with open(filepath, "r", encoding="utf-8") as f:
state_dict = json.load(f)
state = self._deserialize_state(state_dict)
logger.info(f"Loaded agent state from {filepath}")
return state
except Exception as e:
logger.error(f"Failed to load agent state from {filepath}: {e}")
return None
def load_latest_checkpoint(self, agent_id: str) -> Optional[AgentState]:
"""
加载指定 Agent 的最新检查点
Args:
agent_id: Agent ID
Returns:
Agent 状态
"""
# 查找所有匹配的检查点文件
pattern = f"{agent_id}_*.json"
checkpoints = list(self.persist_dir.glob(pattern))
if not checkpoints:
logger.warning(f"No checkpoints found for agent {agent_id}")
return None
# 按修改时间排序,取最新的
latest = max(checkpoints, key=lambda p: p.stat().st_mtime)
return self.load_state(str(latest))
def list_checkpoints(self, agent_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""
列出检查点
Args:
agent_id: Agent ID可选不指定则列出所有
Returns:
检查点信息列表
"""
if agent_id:
pattern = f"{agent_id}_*.json"
else:
pattern = "*.json"
checkpoints = []
for filepath in self.persist_dir.glob(pattern):
stat = filepath.stat()
checkpoints.append({
"filepath": str(filepath),
"filename": filepath.name,
"size_bytes": stat.st_size,
"created_at": datetime.fromtimestamp(stat.st_ctime, tz=timezone.utc).isoformat(),
"modified_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
})
# 按修改时间排序
checkpoints.sort(key=lambda x: x["modified_at"], reverse=True)
return checkpoints
def delete_checkpoint(self, filepath: str) -> bool:
"""
删除检查点
Args:
filepath: 文件路径
Returns:
是否删除成功
"""
try:
os.remove(filepath)
logger.info(f"Deleted checkpoint: {filepath}")
return True
except Exception as e:
logger.error(f"Failed to delete checkpoint {filepath}: {e}")
return False
def cleanup_old_checkpoints(
self,
agent_id: str,
keep_count: int = 5,
) -> int:
"""
清理旧的检查点,只保留最新的几个
Args:
agent_id: Agent ID
keep_count: 保留的检查点数量
Returns:
删除的检查点数量
"""
checkpoints = self.list_checkpoints(agent_id)
if len(checkpoints) <= keep_count:
return 0
# 删除旧的检查点
to_delete = checkpoints[keep_count:]
deleted = 0
for cp in to_delete:
if self.delete_checkpoint(cp["filepath"]):
deleted += 1
return deleted
# ============ 序列化/反序列化 ============
def _serialize_state(self, state: AgentState) -> Dict[str, Any]:
"""序列化 Agent 状态"""
return {
"version": "1.0",
"serialized_at": datetime.now(timezone.utc).isoformat(),
"state": state.model_dump(),
}
def _deserialize_state(self, data: Dict[str, Any]) -> AgentState:
"""反序列化 Agent 状态"""
version = data.get("version", "1.0")
state_data = data.get("state", data)
# 处理版本兼容性
if version == "1.0":
return AgentState(**state_data)
else:
logger.warning(f"Unknown state version: {version}, attempting to load anyway")
return AgentState(**state_data)
# ============ 数据库持久化 ============
async def save_state_to_db(
self,
state: AgentState,
task_id: str,
) -> bool:
"""
保存 Agent 状态到数据库
Args:
state: Agent 状态
task_id: 关联的任务 ID
Returns:
是否保存成功
"""
if not self.use_database or not self.db_session_factory:
logger.warning("Database persistence not configured")
return False
try:
async with self.db_session_factory() as session:
from app.models.agent_task import AgentCheckpoint
checkpoint = AgentCheckpoint(
task_id=task_id,
agent_id=state.agent_id,
agent_name=state.agent_name,
agent_type=state.agent_type,
state_data=state.model_dump_json(),
iteration=state.iteration,
status=state.status,
created_at=datetime.now(timezone.utc),
)
session.add(checkpoint)
await session.commit()
logger.debug(f"Saved agent state to database: {state.agent_id}")
return True
except Exception as e:
logger.error(f"Failed to save agent state to database: {e}")
return False
async def load_state_from_db(
self,
task_id: str,
agent_id: Optional[str] = None,
) -> Optional[AgentState]:
"""
从数据库加载 Agent 状态
Args:
task_id: 任务 ID
agent_id: Agent ID可选
Returns:
Agent 状态
"""
if not self.use_database or not self.db_session_factory:
logger.warning("Database persistence not configured")
return None
try:
async with self.db_session_factory() as session:
from sqlalchemy import select
from app.models.agent_task import AgentCheckpoint
query = select(AgentCheckpoint).where(
AgentCheckpoint.task_id == task_id
)
if agent_id:
query = query.where(AgentCheckpoint.agent_id == agent_id)
query = query.order_by(AgentCheckpoint.created_at.desc()).limit(1)
result = await session.execute(query)
checkpoint = result.scalar_one_or_none()
if checkpoint:
state_data = json.loads(checkpoint.state_data)
return AgentState(**state_data)
return None
except Exception as e:
logger.error(f"Failed to load agent state from database: {e}")
return None
class CheckpointManager:
"""
检查点管理器
提供自动检查点功能:
- 定期保存检查点
- 错误恢复
- 状态回滚
"""
def __init__(
self,
persistence: AgentStatePersistence,
auto_checkpoint_interval: int = 5, # 每 N 次迭代自动保存
):
self.persistence = persistence
self.auto_checkpoint_interval = auto_checkpoint_interval
self._last_checkpoint_iteration: Dict[str, int] = {}
def should_checkpoint(self, state: AgentState) -> bool:
"""
判断是否应该创建检查点
Args:
state: Agent 状态
Returns:
是否应该创建检查点
"""
last_iteration = self._last_checkpoint_iteration.get(state.agent_id, 0)
return state.iteration - last_iteration >= self.auto_checkpoint_interval
def create_checkpoint(
self,
state: AgentState,
checkpoint_name: Optional[str] = None,
) -> str:
"""
创建检查点
Args:
state: Agent 状态
checkpoint_name: 检查点名称
Returns:
检查点文件路径
"""
filepath = self.persistence.save_state(state, checkpoint_name)
self._last_checkpoint_iteration[state.agent_id] = state.iteration
return filepath
def auto_checkpoint(self, state: AgentState) -> Optional[str]:
"""
自动检查点(如果需要)
Args:
state: Agent 状态
Returns:
检查点文件路径,如果没有创建则返回 None
"""
if self.should_checkpoint(state):
return self.create_checkpoint(state)
return None
def restore_from_checkpoint(
self,
agent_id: str,
checkpoint_filepath: Optional[str] = None,
) -> Optional[AgentState]:
"""
从检查点恢复
Args:
agent_id: Agent ID
checkpoint_filepath: 检查点文件路径(可选,不指定则使用最新的)
Returns:
恢复的 Agent 状态
"""
if checkpoint_filepath:
return self.persistence.load_state(checkpoint_filepath)
else:
return self.persistence.load_latest_checkpoint(agent_id)
# 全局持久化管理器
agent_persistence = AgentStatePersistence()
checkpoint_manager = CheckpointManager(agent_persistence)