Add common, refactor eval with eval_policy

This commit is contained in:
Cadene 2024-01-31 13:48:12 +00:00
parent 1e52499490
commit 5a5b190f70
10 changed files with 1590 additions and 64 deletions

View File

View File

View File

@ -0,0 +1,42 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv
from lerobot.common.envs.simxarm import SimxarmEnv
def make_env(cfg):
assert cfg.env == "simxarm"
env = SimxarmEnv(
task=cfg.task,
from_pixels=cfg.from_pixels,
pixels_only=cfg.pixels_only,
image_size=cfg.image_size,
)
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
return env
# def make_env(env_name, frame_skip, device, is_test=False):
# env = GymEnv(
# env_name,
# frame_skip=frame_skip,
# from_pixels=True,
# pixels_only=False,
# device=device,
# )
# env = TransformedEnv(env)
# env.append_transform(NoopResetEnv(noops=30, random=True))
# if not is_test:
# env.append_transform(EndOfLifeTransform())
# env.append_transform(RewardClipping(-1, 1))
# env.append_transform(ToTensorImage())
# env.append_transform(GrayScale())
# env.append_transform(Resize(84, 84))
# env.append_transform(CatFrames(N=4, dim=-3))
# env.append_transform(RewardSum())
# env.append_transform(StepCounter(max_steps=4500))
# env.append_transform(DoubleToFloat())
# env.append_transform(VecNorm(in_keys=["pixels"]))
# return env

View File

@ -0,0 +1,183 @@
import importlib
from typing import Optional
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
class SimxarmEnv(EnvBase):
def __init__(
self,
task,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
):
super().__init__(device=device, batch_size=[])
self.task = task
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
if not _has_simxarm:
raise ImportError("Cannot import simxarm.")
if not _has_gym:
raise ImportError("Cannot import gym.")
import gym
from gym.wrappers import TimeLimit
from simxarm import TASKS
if self.task not in TASKS:
raise ValueError(
f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}"
)
self._env = TASKS[self.task]["env"]()
self._env = TimeLimit(self._env, TASKS[self.task]["episode_length"])
MAX_NUM_ACTIONS = 4
num_actions = len(TASKS[self.task]["action_space"])
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_padding = np.zeros(
(MAX_NUM_ACTIONS - num_actions), dtype=np.float32
)
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0
self._make_spec()
self.set_seed(seed)
def render(self, mode="rgb_array", width=384, height=384):
return self._env.render(mode, width=width, height=height)
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
camera = self.render(
mode="rgb_array", width=self.image_size, height=self.image_size
)
camera = camera.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
camera = torch.tensor(camera.copy(), dtype=torch.uint8)
obs = {"camera": camera}
if not self.pixels_only:
obs["robot_state"] = torch.tensor(
self._env.robot_state, dtype=torch.float32
)
else:
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
obs = TensorDict(obs, batch_size=[])
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
raw_obs = self._env.reset()
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
raw_obs, reward, done, info = self._env.step(action)
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"reward": torch.tensor([reward], dtype=torch.float32),
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([info["success"]], dtype=torch.bool),
},
batch_size=[],
)
return td
def _make_spec(self):
obs = {}
if self.from_pixels:
obs["camera"] = BoundedTensorSpec(
low=0,
high=255,
shape=(3, self.image_size, self.image_size),
dtype=torch.uint8,
device=self.device,
)
if not self.pixels_only:
obs["robot_state"] = UnboundedContinuousTensorSpec(
shape=(len(self._env.robot_state),),
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
obs["state"] = UnboundedContinuousTensorSpec(
shape=self._env.observation_space["observation"].shape,
dtype=torch.float32,
device=self.device,
)
self.observation_spec = CompositeSpec({"observation": obs})
self.action_spec = _gym_to_torchrl_spec_transform(
self._action_space,
device=self.device,
)
self.reward_spec = UnboundedContinuousTensorSpec(
shape=(1,),
dtype=torch.float32,
device=self.device,
)
self.done_spec = DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
)
self.success_spec = DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
self._env.seed(seed)

450
lerobot/common/tdmpc.py Normal file
View File

@ -0,0 +1,450 @@
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
import lerobot.common.tdmpc_helper as h
class TOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
def __init__(self, cfg):
super().__init__()
action_dim = 4
self.cfg = cfg
self._encoder = h.enc(cfg)
self._dynamics = h.dynamics(
cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
)
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
self._V = h.v(cfg)
self.apply(h.orthogonal_init)
for m in [self._reward, *self._Qs]:
m[-1].weight.data.fill_(0)
m[-1].bias.data.fill_(0)
def track_q_grad(self, enable=True):
"""Utility function. Enables/disables gradient tracking of Q-networks."""
for m in self._Qs:
h.set_requires_grad(m, enable)
def track_v_grad(self, enable=True):
"""Utility function. Enables/disables gradient tracking of Q-networks."""
if hasattr(self, "_V"):
h.set_requires_grad(self._V, enable)
def encode(self, obs):
"""Encodes an observation into its latent representation."""
out = self._encoder(obs)
if isinstance(obs, dict):
# fusion
out = torch.stack([v for k, v in out.items()]).mean(dim=0)
return out
def next(self, z, a):
"""Predicts next latent state (d) and single-step reward (R)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x)
def pi(self, z, std=0):
"""Samples an action from the learned policy (pi)."""
mu = torch.tanh(self._pi(z))
if std > 0:
std = torch.ones_like(mu) * std
return h.TruncatedNormal(mu, std).sample(clip=0.3)
return mu
def V(self, z):
"""Predict state value (V)."""
return self._V(z)
def Q(self, z, a, return_type):
"""Predict state-action value (Q)."""
assert return_type in {"min", "avg", "all"}
x = torch.cat([z, a], dim=-1)
if return_type == "all":
return torch.stack(list(q(x) for q in self._Qs), dim=0)
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPC(nn.Module):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg):
super().__init__()
self.action_dim = 4
self.cfg = cfg
self.device = torch.device("cuda")
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg).cuda()
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
self.batch_size = cfg.batch_size
# TODO(rcadene): clean
self.step = 100000
def state_dict(self):
"""Retrieve state dict of TOLD model, including slow-moving target network."""
return {
"model": self.model.state_dict(),
"model_target": self.model_target.state_dict(),
}
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
"""Load a saved state dict from filepath into current agent."""
d = torch.load(fp)
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def forward(self, observation, step_count):
t0 = step_count.item() == 0
obs = {
"rgb": observation["camera"],
"state": observation["robot_state"],
}
return self.act(obs, t0=t0, step=self.step)
@torch.no_grad()
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict):
obs = {
k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
for k, o in obs.items()
}
else:
obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
0
)
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a.cpu()
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
for t in range(horizon):
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
)
z, reward = self.model.next(z, actions[t])
G += discount * reward
discount *= self.cfg.discount
pi = self.model.pi(z, self.cfg.min_std)
G += discount * self.model.Q(z, pi, return_type="min")
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, pi, return_type="all").std(dim=0)
)
return G
@torch.no_grad()
def plan(self, z, step=None, t0=True):
"""
Plan next action using TD-MPC inference.
z: latent state.
step: current time step. determines e.g. planning horizon.
t0: whether current step is the first step of an episode.
"""
# during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
assert step is not None
# Seed steps
if step < self.cfg.seed_steps and self.model.training:
return torch.empty(
self.action_dim, dtype=torch.float32, device=self.device
).uniform_(-1, 1)
# Sample policy trajectories
horizon = int(
min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
)
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
if num_pi_trajs > 0:
pi_actions = torch.empty(
horizon, num_pi_trajs, self.action_dim, device=self.device
)
_z = z.repeat(num_pi_trajs, 1)
for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
_z, _ = self.model.next(_z, pi_actions[t])
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
mean = torch.zeros(horizon, self.action_dim, device=self.device)
std = self.cfg.max_std * torch.ones(
horizon, self.action_dim, device=self.device
)
if not t0 and hasattr(self, "_prev_mean"):
mean[:-1] = self._prev_mean[1:]
# Iterate CEM
for i in range(self.cfg.iterations):
actions = torch.clamp(
mean.unsqueeze(1)
+ std.unsqueeze(1)
* torch.randn(
horizon, self.cfg.num_samples, self.action_dim, device=std.device
),
-1,
1,
)
if num_pi_trajs > 0:
actions = torch.cat([actions, pi_actions], dim=1)
# Compute elite actions
value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
elite_idxs = torch.topk(
value.squeeze(1), self.cfg.num_elites, dim=0
).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
score /= score.sum(0)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
score.sum(0) + 1e-9
)
_std = torch.sqrt(
torch.sum(
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
dim=1,
)
/ (score.sum(0) + 1e-9)
)
_std = _std.clamp_(self.std, self.cfg.max_std)
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
# Outputs
score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean
mean, std = actions[0], _std[0]
a = mean
if self.model.training:
a += std * torch.randn(self.action_dim, device=std.device)
return torch.clamp(a, -1, 1)
def update_pi(self, zs, acts=None):
"""Update policy using a sequence of latent states."""
self.pi_optim.zero_grad(set_to_none=True)
self.model.track_q_grad(False)
self.model.track_v_grad(False)
info = {}
# Advantage Weighted Regression
assert acts is not None
vs = self.model.V(zs)
qs = self.model_target.Q(zs, acts, return_type="min")
adv = qs - vs
exp_a = torch.exp(adv * self.cfg.A_scaling)
exp_a = torch.clamp(exp_a, max=100.0)
log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0)
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean()
info["adv"] = adv[0]
pi_loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model._pi.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.pi_optim.step()
self.model.track_q_grad(True)
self.model.track_v_grad(True)
info["pi_loss"] = pi_loss.item()
return pi_loss.item(), info
@torch.no_grad()
def _td_target(self, next_z, reward, mask):
"""Compute the TD-target from a reward and the observation at the following time step."""
next_v = self.model.V(next_z)
td_target = reward + self.cfg.discount * mask * next_v
return td_target
def update(self, replay_buffer, step, demo_buffer=None):
"""Main update function. Corresponds to one iteration of the model learning."""
if demo_buffer is not None:
# Update oversampling ratio
self.demo_batch_size = int(
h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size
)
replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size
demo_buffer.cfg.batch_size = self.demo_batch_size
else:
self.demo_batch_size = 0
# Sample from interaction dataset
obs, next_obses, action, reward, mask, done, idxs, weights = (
replay_buffer.sample()
)
# Sample from demonstration dataset
if self.demo_batch_size > 0:
(
demo_obs,
demo_next_obses,
demo_action,
demo_reward,
demo_mask,
demo_done,
demo_idxs,
demo_weights,
) = demo_buffer.sample()
if isinstance(obs, dict):
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
next_obses = {
k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
for k in next_obses
}
else:
obs = torch.cat([obs, demo_obs])
next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
action = torch.cat([action, demo_action], dim=1)
reward = torch.cat([reward, demo_reward], dim=1)
mask = torch.cat([mask, demo_mask], dim=1)
done = torch.cat([done, demo_done], dim=1)
idxs = torch.cat([idxs, demo_idxs])
weights = torch.cat([weights, demo_weights])
horizon = self.cfg.horizon
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
self.optim.zero_grad(set_to_none=True)
self.std = h.linear_schedule(self.cfg.std_schedule, step)
self.model.train()
# Compute targets
with torch.no_grad():
next_z = self.model.encode(next_obses)
z_targets = self.model_target.encode(next_obses)
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
zs = torch.empty(
horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
zs[0] = z
value_info = {"Q": 0.0, "V": 0.0}
for t in range(horizon):
z, reward_pred = self.model.next(z, action[t])
zs[t + 1] = z
reward_preds[t] = reward_pred
with torch.no_grad():
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
# Predictions
qs = self.model.Q(zs[:-1], action, return_type="all")
value_info["Q"] = qs.mean().item()
v = self.model.V(zs[:-1])
value_info["V"] = v.mean().item()
# Losses
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(
-1, 1, 1
)
consistency_loss = (
rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
self.expectile = h.linear_schedule(self.cfg.expectile, step)
v_value_loss = (
rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask
).sum(dim=0)
total_loss = (
self.cfg.consistency_coef * consistency_loss
+ self.cfg.reward_coef * reward_loss
+ self.cfg.value_coef * q_value_loss
+ self.cfg.value_coef * v_value_loss
)
weighted_loss = (total_loss.squeeze(1) * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
weighted_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
self.optim.step()
if self.cfg.per:
# Update priorities
priorities = priority_loss.clamp(max=1e4).detach()
replay_buffer.update_priorities(
idxs[: replay_buffer.cfg.batch_size],
priorities[: replay_buffer.cfg.batch_size],
)
if self.demo_batch_size > 0:
demo_buffer.update_priorities(
demo_idxs, priorities[replay_buffer.cfg.batch_size :]
)
# Update policy + target network
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
if step % self.cfg.update_freq == 0:
h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau)
h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau)
self.model.eval()
metrics = {
"consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()),
"Q_value_loss": float(q_value_loss.mean().item()),
"V_value_loss": float(v_value_loss.mean().item()),
"total_loss": float(total_loss.mean().item()),
"weighted_loss": float(weighted_loss.mean().item()),
"grad_norm": float(grad_norm),
}
for key in ["demo_batch_size", "expectile"]:
if hasattr(self, key):
metrics[key] = getattr(self, key)
metrics.update(value_info)
metrics.update(pi_update_info)
return metrics

View File

@ -0,0 +1,829 @@
import os
import pickle
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
__REDUCE__ = lambda b: "mean" if b else "none"
def l1(pred, target, reduce=False):
"""Computes the L1-loss between predictions and targets."""
return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))
def mse(pred, target, reduce=False):
"""Computes the MSE loss between predictions and targets."""
return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))
def l2_expectile(diff, expectile=0.7, reduce=False):
weight = torch.where(diff > 0, expectile, (1 - expectile))
loss = weight * (diff**2)
reduction = __REDUCE__(reduce)
if reduction == "mean":
return torch.mean(loss)
elif reduction == "sum":
return torch.sum(loss)
return loss
def _get_out_shape(in_shape, layers):
"""Utility function. Returns the output shape of a network for a given input shape."""
x = torch.randn(*in_shape).unsqueeze(0)
return (
(nn.Sequential(*layers) if isinstance(layers, list) else layers)(x)
.squeeze(0)
.shape
)
def gaussian_logprob(eps, log_std):
"""Compute Gaussian log probability."""
residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)
def squash(mu, pi, log_pi):
"""Apply squashing function."""
mu = torch.tanh(mu)
pi = torch.tanh(pi)
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
return mu, pi, log_pi
def orthogonal_init(m):
"""Orthogonal layer initialization."""
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
gain = nn.init.calculate_gain("relu")
nn.init.orthogonal_(m.weight.data, gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
def ema(m, m_target, tau):
"""Update slow-moving average of online network (target network) at rate tau."""
with torch.no_grad():
for p, p_target in zip(m.parameters(), m_target.parameters()):
p_target.data.lerp_(p.data, tau)
def set_requires_grad(net, value):
"""Enable/disable gradients for a given (sub)network."""
for param in net.parameters():
param.requires_grad_(value)
class TruncatedNormal(pyd.Normal):
"""Utility class implementing the truncated normal distribution."""
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
super().__init__(loc, scale, validate_args=False)
self.low = low
self.high = high
self.eps = eps
def _clamp(self, x):
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
x = x - x.detach() + clamped_x.detach()
return x
def sample(self, clip=None, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
eps *= self.scale
if clip is not None:
eps = torch.clamp(eps, -clip, clip)
x = self.loc + eps
return self._clamp(x)
class NormalizeImg(nn.Module):
"""Normalizes pixel observations to [0,1) range."""
def __init__(self):
super().__init__()
def forward(self, x):
return x.div(255.0)
class Flatten(nn.Module):
"""Flattens its input to a (batched) vector."""
def __init__(self):
super().__init__()
def forward(self, x):
return x.view(x.size(0), -1)
def enc(cfg):
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,),
}
"""Returns a TOLD encoder."""
pixels_enc_layers, state_enc_layers = None, None
if cfg.modality in {"pixels", "all"}:
C = int(3 * cfg.frame_stack)
pixels_enc_layers = [
NormalizeImg(),
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
nn.ReLU(),
]
out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
pixels_enc_layers.extend(
[
Flatten(),
nn.Linear(np.prod(out_shape), cfg.latent_dim),
nn.LayerNorm(cfg.latent_dim),
nn.Sigmoid(),
]
)
if cfg.modality == "pixels":
return ConvExt(nn.Sequential(*pixels_enc_layers))
if cfg.modality in {"state", "all"}:
state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0]
state_enc_layers = [
nn.Linear(state_dim, cfg.enc_dim),
nn.ELU(),
nn.Linear(cfg.enc_dim, cfg.latent_dim),
nn.LayerNorm(cfg.latent_dim),
nn.Sigmoid(),
]
if cfg.modality == "state":
return nn.Sequential(*state_enc_layers)
else:
raise NotImplementedError
encoders = {}
for k in obs_shape:
if k == "state":
encoders[k] = nn.Sequential(*state_enc_layers)
elif k.endswith("rgb"):
encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
else:
raise NotImplementedError
return Multiplexer(nn.ModuleDict(encoders))
def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
"""Returns an MLP."""
if isinstance(mlp_dim, int):
mlp_dim = [mlp_dim, mlp_dim]
return nn.Sequential(
nn.Linear(in_dim, mlp_dim[0]),
nn.LayerNorm(mlp_dim[0]),
act_fn,
nn.Linear(mlp_dim[0], mlp_dim[1]),
nn.LayerNorm(mlp_dim[1]),
act_fn,
nn.Linear(mlp_dim[1], out_dim),
)
def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
"""Returns a dynamics network."""
return nn.Sequential(
mlp(in_dim, mlp_dim, out_dim, act_fn),
nn.LayerNorm(out_dim),
nn.Sigmoid(),
)
def q(cfg):
action_dim = 4
"""Returns a Q-function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
nn.LayerNorm(cfg.mlp_dim),
nn.Tanh(),
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
nn.ELU(),
nn.Linear(cfg.mlp_dim, 1),
)
def v(cfg):
"""Returns a state value function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim, cfg.mlp_dim),
nn.LayerNorm(cfg.mlp_dim),
nn.Tanh(),
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
nn.ELU(),
nn.Linear(cfg.mlp_dim, 1),
)
def aug(cfg):
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,),
}
"""Multiplex augmentation"""
if cfg.modality == "state":
return nn.Identity()
elif cfg.modality == "pixels":
return RandomShiftsAug(cfg)
else:
augs = {}
for k in obs_shape:
if k == "state":
augs[k] = nn.Identity()
elif k.endswith("rgb"):
augs[k] = RandomShiftsAug(cfg)
else:
raise NotImplementedError
return Multiplexer(nn.ModuleDict(augs))
class ConvExt(nn.Module):
"""Auxiliary conv net accommodating high-dim input"""
def __init__(self, conv):
super().__init__()
self.conv = conv
def forward(self, x):
if x.ndim > 4:
batch_shape = x.shape[:-3]
out = self.conv(x.view(-1, *x.shape[-3:]))
out = out.view(*batch_shape, *out.shape[1:])
else:
out = self.conv(x)
return out
class Multiplexer(nn.Module):
"""Model multiplexer"""
def __init__(self, choices):
super().__init__()
self.choices = choices
def forward(self, x, key=None):
if isinstance(x, dict):
if key is not None:
return self.choices[key](x)
return {k: self.choices[k](_x) for k, _x in x.items()}
return self.choices(x)
class RandomShiftsAug(nn.Module):
"""
Random shift image augmentation.
Adapted from https://github.com/facebookresearch/drqv2
"""
def __init__(self, cfg):
super().__init__()
assert cfg.modality in {"pixels", "all"}
self.pad = int(cfg.img_size / 21)
def forward(self, x):
n, c, h, w = x.size()
assert h == w
padding = tuple([self.pad] * 4)
x = F.pad(x, padding, "replicate")
eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(
-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
shift = torch.randint(
0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
)
shift *= 2.0 / (h + 2 * self.pad)
grid = base_grid + shift
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
class Episode(object):
"""Storage object for a single episode."""
def __init__(self, cfg, init_obs):
action_dim = 4
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)
if cfg.modality in {"pixels", "state"}:
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
self.obses = torch.empty(
(cfg.episode_length + 1, *init_obs.shape),
dtype=dtype,
device=self.device,
)
self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
elif cfg.modality == "all":
self.obses = {}
for k, v in init_obs.items():
assert k in {"rgb", "state"}
dtype = torch.float32 if k == "state" else torch.uint8
self.obses[k] = torch.empty(
(cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
)
self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
else:
raise ValueError
self.actions = torch.empty(
(cfg.episode_length, action_dim), dtype=torch.float32, device=self.device
)
self.rewards = torch.empty(
(cfg.episode_length,), dtype=torch.float32, device=self.device
)
self.dones = torch.empty(
(cfg.episode_length,), dtype=torch.bool, device=self.device
)
self.masks = torch.empty(
(cfg.episode_length,), dtype=torch.float32, device=self.device
)
self.cumulative_reward = 0
self.done = False
self.success = False
self._idx = 0
def __len__(self):
return self._idx
@classmethod
def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
"""Constructs an episode from a trajectory."""
if cfg.modality in {"pixels", "state"}:
episode = cls(cfg, obses[0])
episode.obses[1:] = torch.tensor(
obses[1:], dtype=episode.obses.dtype, device=episode.device
)
elif cfg.modality == "all":
episode = cls(cfg, {k: v[0] for k, v in obses.items()})
for k, v in obses.items():
episode.obses[k][1:] = torch.tensor(
obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
)
else:
raise NotImplementedError
episode.actions = torch.tensor(
actions, dtype=episode.actions.dtype, device=episode.device
)
episode.rewards = torch.tensor(
rewards, dtype=episode.rewards.dtype, device=episode.device
)
episode.dones = (
torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
if dones is not None
else torch.zeros_like(episode.dones)
)
episode.masks = (
torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
if masks is not None
else torch.ones_like(episode.masks)
)
episode.cumulative_reward = torch.sum(episode.rewards)
episode.done = True
episode._idx = cfg.episode_length
return episode
@property
def first(self):
return len(self) == 0
def __add__(self, transition):
self.add(*transition)
return self
def add(self, obs, action, reward, done, mask=1.0, success=False):
"""Add a transition into the episode."""
if isinstance(obs, dict):
for k, v in obs.items():
self.obses[k][self._idx + 1] = torch.tensor(
v, dtype=self.obses[k].dtype, device=self.obses[k].device
)
else:
self.obses[self._idx + 1] = torch.tensor(
obs, dtype=self.obses.dtype, device=self.obses.device
)
self.actions[self._idx] = action
self.rewards[self._idx] = reward
self.dones[self._idx] = done
self.masks[self._idx] = mask
self.cumulative_reward += reward
self.done = done
self.success = self.success or success
self._idx += 1
class ReplayBuffer:
"""
Storage and sampling functionality.
"""
def __init__(self, cfg, dataset=None):
action_dim = 4
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)}
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)
print("Replay buffer device: ", self.device)
if dataset is not None:
self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size)
else:
self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
if cfg.modality in {"pixels", "state"}:
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
# Note self.obs_shape always has single frame, which is different from cfg.obs_shape
self.obs_shape = (
obs_shape if cfg.modality == "state" else (3, *obs_shape[-2:])
)
self._obs = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape),
dtype=dtype,
device=self.device,
)
self._next_obs = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape),
dtype=dtype,
device=self.device,
)
elif cfg.modality == "all":
self.obs_shape = {}
self._obs, self._next_obs = {}, {}
for k, v in obs_shape.items():
assert k in {"rgb", "state"}
dtype = torch.float32 if k == "state" else torch.uint8
self.obs_shape[k] = v if k == "state" else (3, *v[-2:])
self._obs[k] = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape[k]),
dtype=dtype,
device=self.device,
)
self._next_obs[k] = self._obs[k].clone()
else:
raise ValueError
self._action = torch.zeros(
(self.capacity + cfg.horizon - 1, action_dim),
dtype=torch.float32,
device=self.device,
)
self._reward = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._mask = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._done = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.bool, device=self.device
)
self._priorities = torch.ones(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._eps = 1e-6
self._full = False
self.idx = 0
if dataset is not None:
self.init_from_offline_dataset(dataset)
self._aug = aug(cfg)
def init_from_offline_dataset(self, dataset):
"""Initialize the replay buffer from an offline dataset."""
assert self.idx == 0 and not self._full
n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)
def copy_data(dst, src, n):
assert isinstance(dst, dict) == isinstance(src, dict)
if isinstance(dst, dict):
for k in dst:
copy_data(dst[k], src[k], n)
else:
dst[:n] = torch.from_numpy(src[:n])
copy_data(self._obs, dataset["observations"], n_transitions)
copy_data(self._next_obs, dataset["next_observations"], n_transitions)
copy_data(self._action, dataset["actions"], n_transitions)
copy_data(self._reward, dataset["rewards"], n_transitions)
copy_data(self._mask, dataset["masks"], n_transitions)
copy_data(self._done, dataset["dones"], n_transitions)
self.idx = (self.idx + n_transitions) % self.capacity
self._full = n_transitions >= self.capacity
def __add__(self, episode: Episode):
self.add(episode)
return self
def add(self, episode: Episode):
"""Add an episode to the replay buffer."""
if self.idx + len(episode) > self.capacity:
print("Warning: episode got truncated")
ep_len = min(len(episode), self.capacity - self.idx)
idxs = slice(self.idx, self.idx + ep_len)
assert self.idx + ep_len <= self.capacity
if self.cfg.modality in {"pixels", "state"}:
self._obs[idxs] = (
episode.obses[:ep_len]
if self.cfg.modality == "state"
else episode.obses[:ep_len, -3:]
)
self._next_obs[idxs] = (
episode.obses[1 : ep_len + 1]
if self.cfg.modality == "state"
else episode.obses[1 : ep_len + 1, -3:]
)
elif self.cfg.modality == "all":
for k, v in episode.obses.items():
assert k in {"rgb", "state"}
assert k in self._obs
assert k in self._next_obs
if k == "rgb":
self._obs[k][idxs] = episode.obses[k][:ep_len, -3:]
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:]
else:
self._obs[k][idxs] = episode.obses[k][:ep_len]
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1]
self._action[idxs] = episode.actions[:ep_len]
self._reward[idxs] = episode.rewards[:ep_len]
self._mask[idxs] = episode.masks[:ep_len]
self._done[idxs] = episode.dones[:ep_len]
self._done[self.idx + ep_len - 1] = True # in case truncated
if self._full:
max_priority = (
self._priorities[: self.capacity].max().to(self.device).item()
)
else:
max_priority = (
1.0
if self.idx == 0
else self._priorities[: self.idx].max().to(self.device).item()
)
new_priorities = torch.full((ep_len,), max_priority, device=self.device)
self._priorities[idxs] = new_priorities
self.idx = (self.idx + ep_len) % self.capacity
self._full = self._full or self.idx == 0
def update_priorities(self, idxs, priorities):
"""Update priorities for Prioritized Experience Replay (PER)"""
self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps
def _get_obs(self, arr, idxs):
"""Retrieve observations by indices"""
if isinstance(arr, dict):
return {k: self._get_obs(v, idxs) for k, v in arr.items()}
if arr.ndim <= 2: # if self.cfg.modality == 'state':
return arr[idxs].cuda()
obs = torch.empty(
(self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]),
dtype=arr.dtype,
device=torch.device("cuda"),
)
obs[:, -3:] = arr[idxs].cuda()
_idxs = idxs.clone()
mask = torch.ones_like(_idxs, dtype=torch.bool)
for i in range(1, self.cfg.frame_stack):
mask[_idxs % self.cfg.episode_length == 0] = False
_idxs[mask] -= 1
obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda()
return obs.float()
def sample(self):
"""Sample transitions from the replay buffer."""
probs = (
self._priorities[: self.capacity]
if self._full
else self._priorities[: self.idx]
) ** self.cfg.per_alpha
probs /= probs.sum()
total = len(probs)
idxs = torch.from_numpy(
np.random.choice(
total,
self.cfg.batch_size,
p=probs.cpu().numpy(),
replace=not self._full,
)
).to(self.device)
weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
weights /= weights.max()
idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)])
obs = self._aug(self._get_obs(self._obs, idxs))
next_obs = [
self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon
]
if isinstance(next_obs[0], dict):
next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
else:
next_obs = torch.stack(next_obs)
action = self._action[idxs_in_horizon]
reward = self._reward[idxs_in_horizon]
mask = self._mask[idxs_in_horizon]
done = self._done[idxs_in_horizon]
if not action.is_cuda:
action, reward, mask, done, idxs, weights = (
action.cuda(),
reward.cuda(),
mask.cuda(),
done.cuda(),
idxs.cuda(),
weights.cuda(),
)
return (
obs,
next_obs,
action,
reward.unsqueeze(2),
mask.unsqueeze(2),
done.unsqueeze(2),
idxs,
weights,
)
def save(self, path):
"""Save the replay buffer to path"""
print(f"saving replay buffer to '{path}'...")
sz = self.capacity if self._full else self.idx
dataset = {
"observations": (
{k: v[:sz].cpu().numpy() for k, v in self._obs.items()}
if isinstance(self._obs, dict)
else self._obs[:sz].cpu().numpy()
),
"next_observations": (
{k: v[:sz].cpu().numpy() for k, v in self._next_obs.items()}
if isinstance(self._next_obs, dict)
else self._next_obs[:sz].cpu().numpy()
),
"actions": self._action[:sz].cpu().numpy(),
"rewards": self._reward[:sz].cpu().numpy(),
"dones": self._done[:sz].cpu().numpy(),
"masks": self._mask[:sz].cpu().numpy(),
}
with open(path, "wb") as f:
pickle.dump(dataset, f)
return dataset
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
"""Construct a dataset for env"""
required_keys = [
"observations",
"next_observations",
"actions",
"rewards",
"dones",
"masks",
]
if cfg.task.startswith("xarm"):
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
for k in required_keys:
if k not in dataset_dict and k[:-1] in dataset_dict:
dataset_dict[k] = dataset_dict.pop(k[:-1])
elif cfg.task.startswith("legged"):
dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
dataset_dict["actions"] /= env.unwrapped.clip_actions
print(f"clip_actions={env.unwrapped.clip_actions}")
else:
import d4rl
dataset_dict = d4rl.qlearning_dataset(env)
dones = np.full_like(dataset_dict["rewards"], False, dtype=bool)
for i in range(len(dones) - 1):
if (
np.linalg.norm(
dataset_dict["observations"][i + 1]
- dataset_dict["next_observations"][i]
)
> 1e-6
or dataset_dict["terminals"][i] == 1.0
):
dones[i] = True
dones[-1] = True
dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
del dataset_dict["terminals"]
for k, v in dataset_dict.items():
dataset_dict[k] = v.astype(np.float32)
dataset_dict["dones"] = dones
if cfg.is_data_clip:
lim = 1 - cfg.data_clip_eps
dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)
reward_normalizer = get_reward_normalizer(cfg, dataset_dict)
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
for key in required_keys:
assert key in dataset_dict.keys(), f"Missing `{key}` in dataset."
if return_reward_normalizer:
return dataset_dict, reward_normalizer
return dataset_dict
def get_trajectory_boundaries_and_returns(dataset):
"""
Split dataset into trajectories and compute returns
"""
episode_starts = [0]
episode_ends = []
episode_return = 0
episode_returns = []
n_transitions = len(dataset["rewards"])
for i in range(n_transitions):
episode_return += dataset["rewards"][i]
if dataset["dones"][i]:
episode_returns.append(episode_return)
episode_ends.append(i + 1)
if i + 1 < n_transitions:
episode_starts.append(i + 1)
episode_return = 0.0
return episode_starts, episode_ends, episode_returns
def normalize_returns(dataset, scaling=1000):
"""
Normalize returns in the dataset
"""
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
dataset["rewards"] *= scaling
return dataset
def get_reward_normalizer(cfg, dataset):
"""
Get a reward normalizer for the dataset
"""
if cfg.task.startswith("xarm"):
return lambda x: x
elif "maze" in cfg.task:
return lambda x: x - 1.0
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
return (
lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
)
elif hasattr(cfg, "reward_scale"):
return lambda x: x * cfg.reward_scale
return lambda x: x
def linear_schedule(schdl, step):
"""
Outputs values following a linear decay schedule.
Adapted from https://github.com/facebookresearch/drqv2
"""
try:
return float(schdl)
except ValueError:
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
if match:
init, final, start, end = [float(g) for g in match.groups()]
mix = np.clip((step - start) / (end - start), 0.0, 1.0)
return (1.0 - mix) * init + mix * final
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
if match:
init, final, duration = [float(g) for g in match.groups()]
mix = np.clip(step / duration, 0.0, 1.0)
return (1.0 - mix) * init + mix * final
raise NotImplementedError(schdl)

12
lerobot/common/utils.py Normal file
View File

@ -0,0 +1,12 @@
import random
import numpy as np
import torch
def set_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

View File

@ -5,70 +5,54 @@ import imageio
import numpy as np
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from termcolor import colored
from lerobot.lib.envs.factory import make_env
from lerobot.lib.tdmpc import TDMPC
from lerobot.lib.utils import set_seed
from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC
from lerobot.common.utils import set_seed
def eval_agent(
env, agent, num_episodes: int, save_video: bool = False, video_path: Path = None
def eval_policy(
env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None
):
"""Evaluate a trained agent and optionally save a video."""
if save_video:
assert video_path is not None
assert video_path.suffix == ".mp4"
episode_rewards = []
episode_successes = []
episode_lengths = []
rewards = []
successes = []
for i in range(num_episodes):
td = env.reset()
obs = {}
obs["rgb"] = td["observation"]["camera"]
obs["state"] = td["observation"]["robot_state"]
ep_frames = []
done = False
ep_reward = 0
t = 0
ep_success = False
def rendering_callback(env, td=None):
nonlocal ep_frames
frame = env.render()
ep_frames.append(frame)
tensordict = env.reset()
# render first frame before rollout
rendering_callback(env)
rollout = env.rollout(
max_steps=30,
policy=policy,
callback=rendering_callback,
auto_reset=False,
tensordict=tensordict,
)
ep_reward = rollout["next", "reward"].sum()
ep_success = rollout["next", "success"].any()
rewards.append(ep_reward.item())
successes.append(ep_success.item())
if save_video:
frames = []
while not done:
action = agent.act(obs, t0=t == 0, eval_mode=True, step=100000)
td = TensorDict({"action": action}, batch_size=[])
td = env.step(td)
reward = td["next", "reward"].item()
success = td["next", "success"].item()
done = td["next", "done"].item()
obs = {}
obs["rgb"] = td["next", "observation"]["camera"]
obs["state"] = td["next", "observation"]["robot_state"]
ep_reward += reward
if success:
ep_success = True
if save_video:
frame = env.render()
frames.append(frame)
t += 1
episode_rewards.append(float(ep_reward))
episode_successes.append(float(ep_success))
episode_lengths.append(t)
if save_video:
video_path.parent.mkdir(parents=True, exist_ok=True)
frames = np.stack(frames) # .transpose(0, 3, 1, 2)
video_dir.parent.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
imageio.mimsave(video_path, frames, fps=15)
return {
"episode_reward": np.nanmean(episode_rewards),
"episode_success": np.nanmean(episode_successes),
"episode_length": np.nanmean(episode_lengths),
video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
metrics = {
"avg_reward": np.nanmean(rewards),
"pc_success": np.nanmean(successes) * 100,
}
return metrics
@hydra.main(version_base=None, config_name="default", config_path="../configs")
@ -78,20 +62,25 @@ def eval(cfg: dict):
print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
env = make_env(cfg)
agent = TDMPC(cfg)
policy = TDMPC(cfg)
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
agent.load(ckpt_path)
policy.load(ckpt_path)
eval_metrics = eval_agent(
env,
agent,
num_episodes=10,
save_video=True,
video_path=Path("tmp/2023_01_29_xarm_lift_final/eval.mp4"),
policy = TensorDictModule(
policy,
in_keys=["observation", "step_count"],
out_keys=["action"],
)
print(eval_metrics)
metrics = eval_policy(
env,
policy,
num_episodes=10,
save_video=True,
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
)
print(metrics)
if __name__ == "__main__":

View File

@ -2,7 +2,10 @@ import hydra
import torch
from termcolor import colored
from ..lib.utils import set_seed
from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC
from ..common.utils import set_seed
@hydra.main(version_base=None, config_name="default", config_path="../configs")
@ -11,6 +14,24 @@ def train(cfg: dict):
set_seed(cfg.seed)
print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
env = make_env(cfg)
agent = TDMPC(cfg)
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
agent.load(ckpt_path)
# online training
eval_metrics = train_agent(
env,
agent,
num_episodes=10,
save_video=True,
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
)
print(eval_metrics)
if __name__ == "__main__":
train()

View File

@ -2,7 +2,7 @@ import pytest
from tensordict import TensorDict
from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.lib.envs import SimxarmEnv
from lerobot.common.envs import SimxarmEnv
@pytest.mark.parametrize(