Add common, refactor eval with eval_policy

Cadene 2024-01-31 13:48:12 +00:00
parent 1e52499490
commit 5a5b190f70
10 changed files with 1590 additions and 64 deletions


@@ -0,0 +1,42 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv
from lerobot.common.envs.simxarm import SimxarmEnv
def make_env(cfg):
assert cfg.env == "simxarm"
env = SimxarmEnv(
task=cfg.task,
from_pixels=cfg.from_pixels,
pixels_only=cfg.pixels_only,
image_size=cfg.image_size,
)
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
return env
# def make_env(env_name, frame_skip, device, is_test=False):
# env = GymEnv(
# env_name,
# frame_skip=frame_skip,
# from_pixels=True,
# pixels_only=False,
# device=device,
# )
# env = TransformedEnv(env)
# env.append_transform(NoopResetEnv(noops=30, random=True))
# if not is_test:
# env.append_transform(EndOfLifeTransform())
# env.append_transform(RewardClipping(-1, 1))
# env.append_transform(ToTensorImage())
# env.append_transform(GrayScale())
# env.append_transform(Resize(84, 84))
# env.append_transform(CatFrames(N=4, dim=-3))
# env.append_transform(RewardSum())
# env.append_transform(StepCounter(max_steps=4500))
# env.append_transform(DoubleToFloat())
# env.append_transform(VecNorm(in_keys=["pixels"]))
# return env
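As a quick sanity check, the factory above could be exercised as follows (a sketch only; the task name and values are illustrative, and OmegaConf is used purely to stand in for the Hydra config):

from omegaconf import OmegaConf

# Hypothetical config; the field names mirror the ones read by make_env above.
cfg = OmegaConf.create(
    {
        "env": "simxarm",
        "task": "lift",          # illustrative task name
        "from_pixels": True,
        "pixels_only": False,
        "image_size": 84,
        "episode_length": 25,
    }
)
env = make_env(cfg)  # TransformedEnv wrapping SimxarmEnv, capped at cfg.episode_length steps
rollout = env.rollout(max_steps=10)  # random actions; returns a TensorDict of transitions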


@@ -0,0 +1,183 @@
import importlib
from typing import Optional
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
class SimxarmEnv(EnvBase):
def __init__(
self,
task,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
):
super().__init__(device=device, batch_size=[])
self.task = task
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
if not _has_simxarm:
raise ImportError("Cannot import simxarm.")
if not _has_gym:
raise ImportError("Cannot import gym.")
import gym
from gym.wrappers import TimeLimit
from simxarm import TASKS
if self.task not in TASKS:
raise ValueError(
f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}"
)
self._env = TASKS[self.task]["env"]()
self._env = TimeLimit(self._env, TASKS[self.task]["episode_length"])
MAX_NUM_ACTIONS = 4
num_actions = len(TASKS[self.task]["action_space"])
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_padding = np.zeros(
(MAX_NUM_ACTIONS - num_actions), dtype=np.float32
)
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0
self._make_spec()
self.set_seed(seed)
def render(self, mode="rgb_array", width=384, height=384):
return self._env.render(mode, width=width, height=height)
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
camera = self.render(
mode="rgb_array", width=self.image_size, height=self.image_size
)
camera = camera.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
camera = torch.tensor(camera.copy(), dtype=torch.uint8)
obs = {"camera": camera}
if not self.pixels_only:
obs["robot_state"] = torch.tensor(
self._env.robot_state, dtype=torch.float32
)
else:
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
obs = TensorDict(obs, batch_size=[])
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
raw_obs = self._env.reset()
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
raw_obs, reward, done, info = self._env.step(action)
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"reward": torch.tensor([reward], dtype=torch.float32),
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([info["success"]], dtype=torch.bool),
},
batch_size=[],
)
return td
def _make_spec(self):
obs = {}
if self.from_pixels:
obs["camera"] = BoundedTensorSpec(
low=0,
high=255,
shape=(3, self.image_size, self.image_size),
dtype=torch.uint8,
device=self.device,
)
if not self.pixels_only:
obs["robot_state"] = UnboundedContinuousTensorSpec(
shape=(len(self._env.robot_state),),
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
obs["state"] = UnboundedContinuousTensorSpec(
shape=self._env.observation_space["observation"].shape,
dtype=torch.float32,
device=self.device,
)
self.observation_spec = CompositeSpec({"observation": obs})
self.action_spec = _gym_to_torchrl_spec_transform(
self._action_space,
device=self.device,
)
self.reward_spec = UnboundedContinuousTensorSpec(
shape=(1,),
dtype=torch.float32,
device=self.device,
)
self.done_spec = DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
)
self.success_spec = DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
self._env.seed(seed)
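A minimal smoke test of this wrapper might look like the sketch below (it assumes the simxarm and gym packages are installed; the task name is illustrative):

from torchrl.envs.utils import check_env_specs

env = SimxarmEnv(task="lift", from_pixels=True, pixels_only=False, image_size=84)
check_env_specs(env)  # cross-checks the specs declared in _make_spec against actual rollouts

td = env.reset()                       # TensorDict with "observation" and "done"
td["action"] = env.action_spec.rand()  # random action within the Box bounds
td = env.step(td)
print(td["next", "reward"], td["next", "success"])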

lerobot/common/tdmpc.py (new file, 450 lines)

@@ -0,0 +1,450 @@
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
import lerobot.common.tdmpc_helper as h
class TOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
def __init__(self, cfg):
super().__init__()
action_dim = 4
self.cfg = cfg
self._encoder = h.enc(cfg)
self._dynamics = h.dynamics(
cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim
)
self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1)
self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim)
self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)])
self._V = h.v(cfg)
self.apply(h.orthogonal_init)
for m in [self._reward, *self._Qs]:
m[-1].weight.data.fill_(0)
m[-1].bias.data.fill_(0)
def track_q_grad(self, enable=True):
"""Utility function. Enables/disables gradient tracking of Q-networks."""
for m in self._Qs:
h.set_requires_grad(m, enable)
def track_v_grad(self, enable=True):
"""Utility function. Enables/disables gradient tracking of Q-networks."""
if hasattr(self, "_V"):
h.set_requires_grad(self._V, enable)
def encode(self, obs):
"""Encodes an observation into its latent representation."""
out = self._encoder(obs)
if isinstance(obs, dict):
# fusion
out = torch.stack([v for k, v in out.items()]).mean(dim=0)
return out
def next(self, z, a):
"""Predicts next latent state (d) and single-step reward (R)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x)
def pi(self, z, std=0):
"""Samples an action from the learned policy (pi)."""
mu = torch.tanh(self._pi(z))
if std > 0:
std = torch.ones_like(mu) * std
return h.TruncatedNormal(mu, std).sample(clip=0.3)
return mu
def V(self, z):
"""Predict state value (V)."""
return self._V(z)
def Q(self, z, a, return_type):
"""Predict state-action value (Q)."""
assert return_type in {"min", "avg", "all"}
x = torch.cat([z, a], dim=-1)
if return_type == "all":
return torch.stack(list(q(x) for q in self._Qs), dim=0)
idxs = np.random.choice(self.cfg.num_q, 2, replace=False)
Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x)
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
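# With return_type "min" or "avg", Q() evaluates a random pair drawn from the ensemble of
# cfg.num_q critics and reduces over that pair only, which keeps the value estimate
# pessimistic (in the spirit of clipped double Q-learning) while every critic still gets trained.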
class TDMPC(nn.Module):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg):
super().__init__()
self.action_dim = 4
self.cfg = cfg
self.device = torch.device("cuda")
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg).cuda()
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
self.batch_size = cfg.batch_size
# TODO(rcadene): clean
self.step = 100000
def state_dict(self):
"""Retrieve state dict of TOLD model, including slow-moving target network."""
return {
"model": self.model.state_dict(),
"model_target": self.model_target.state_dict(),
}
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
"""Load a saved state dict from filepath into current agent."""
d = torch.load(fp)
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def forward(self, observation, step_count):
t0 = step_count.item() == 0
obs = {
"rgb": observation["camera"],
"state": observation["robot_state"],
}
return self.act(obs, t0=t0, step=self.step)
@torch.no_grad()
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict):
obs = {
k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
for k, o in obs.items()
}
else:
obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
0
)
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a.cpu()
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
for t in range(horizon):
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
)
z, reward = self.model.next(z, actions[t])
G += discount * reward
discount *= self.cfg.discount
pi = self.model.pi(z, self.cfg.min_std)
G += discount * self.model.Q(z, pi, return_type="min")
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, pi, return_type="all").std(dim=0)
)
return G
@torch.no_grad()
def plan(self, z, step=None, t0=True):
"""
Plan next action using TD-MPC inference.
z: latent state.
step: current time step. determines e.g. planning horizon.
t0: whether current step is the first step of an episode.
"""
# During evaluation (self.model.eval()), uniform seed-step sampling and the exploration noise added below are disabled.
assert step is not None
# Seed steps
if step < self.cfg.seed_steps and self.model.training:
return torch.empty(
self.action_dim, dtype=torch.float32, device=self.device
).uniform_(-1, 1)
# Sample policy trajectories
horizon = int(
min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))
)
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
if num_pi_trajs > 0:
pi_actions = torch.empty(
horizon, num_pi_trajs, self.action_dim, device=self.device
)
_z = z.repeat(num_pi_trajs, 1)
for t in range(horizon):
pi_actions[t] = self.model.pi(_z, self.cfg.min_std)
_z, _ = self.model.next(_z, pi_actions[t])
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1)
mean = torch.zeros(horizon, self.action_dim, device=self.device)
std = self.cfg.max_std * torch.ones(
horizon, self.action_dim, device=self.device
)
if not t0 and hasattr(self, "_prev_mean"):
mean[:-1] = self._prev_mean[1:]
# Iterate CEM
for i in range(self.cfg.iterations):
actions = torch.clamp(
mean.unsqueeze(1)
+ std.unsqueeze(1)
* torch.randn(
horizon, self.cfg.num_samples, self.action_dim, device=std.device
),
-1,
1,
)
if num_pi_trajs > 0:
actions = torch.cat([actions, pi_actions], dim=1)
# Compute elite actions
value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
elite_idxs = torch.topk(
value.squeeze(1), self.cfg.num_elites, dim=0
).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters
max_value = elite_value.max(0)[0]
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
score /= score.sum(0)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (
score.sum(0) + 1e-9
)
_std = torch.sqrt(
torch.sum(
score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2,
dim=1,
)
/ (score.sum(0) + 1e-9)
)
_std = _std.clamp_(self.std, self.cfg.max_std)
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
# Outputs
score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean
mean, std = actions[0], _std[0]
a = mean
if self.model.training:
a += std * torch.randn(self.action_dim, device=std.device)
return torch.clamp(a, -1, 1)
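# The loop above implements the cross-entropy method (CEM): sample action sequences around the
# current mean/std, score them with estimate_value(), keep the top cfg.num_elites, refit mean/std
# to the elites with a softmax weighting (cfg.temperature), and warm-start the next call via _prev_mean.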
def update_pi(self, zs, acts=None):
"""Update policy using a sequence of latent states."""
self.pi_optim.zero_grad(set_to_none=True)
self.model.track_q_grad(False)
self.model.track_v_grad(False)
info = {}
# Advantage Weighted Regression
assert acts is not None
vs = self.model.V(zs)
qs = self.model_target.Q(zs, acts, return_type="min")
adv = qs - vs
exp_a = torch.exp(adv * self.cfg.A_scaling)
exp_a = torch.clamp(exp_a, max=100.0)
log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0)
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean()
info["adv"] = adv[0]
pi_loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model._pi.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.pi_optim.step()
self.model.track_q_grad(True)
self.model.track_v_grad(True)
info["pi_loss"] = pi_loss.item()
return pi_loss.item(), info
@torch.no_grad()
def _td_target(self, next_z, reward, mask):
"""Compute the TD-target from a reward and the observation at the following time step."""
next_v = self.model.V(next_z)
td_target = reward + self.cfg.discount * mask * next_v
return td_target
def update(self, replay_buffer, step, demo_buffer=None):
"""Main update function. Corresponds to one iteration of the model learning."""
if demo_buffer is not None:
# Update oversampling ratio
self.demo_batch_size = int(
h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size
)
replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size
demo_buffer.cfg.batch_size = self.demo_batch_size
else:
self.demo_batch_size = 0
# Sample from interaction dataset
obs, next_obses, action, reward, mask, done, idxs, weights = (
replay_buffer.sample()
)
# Sample from demonstration dataset
if self.demo_batch_size > 0:
(
demo_obs,
demo_next_obses,
demo_action,
demo_reward,
demo_mask,
demo_done,
demo_idxs,
demo_weights,
) = demo_buffer.sample()
if isinstance(obs, dict):
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
next_obses = {
k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1)
for k in next_obses
}
else:
obs = torch.cat([obs, demo_obs])
next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
action = torch.cat([action, demo_action], dim=1)
reward = torch.cat([reward, demo_reward], dim=1)
mask = torch.cat([mask, demo_mask], dim=1)
done = torch.cat([done, demo_done], dim=1)
idxs = torch.cat([idxs, demo_idxs])
weights = torch.cat([weights, demo_weights])
horizon = self.cfg.horizon
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
self.optim.zero_grad(set_to_none=True)
self.std = h.linear_schedule(self.cfg.std_schedule, step)
self.model.train()
# Compute targets
with torch.no_grad():
next_z = self.model.encode(next_obses)
z_targets = self.model_target.encode(next_obses)
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
zs = torch.empty(
horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device
)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
zs[0] = z
value_info = {"Q": 0.0, "V": 0.0}
for t in range(horizon):
z, reward_pred = self.model.next(z, action[t])
zs[t + 1] = z
reward_preds[t] = reward_pred
with torch.no_grad():
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
# Predictions
qs = self.model.Q(zs[:-1], action, return_type="all")
value_info["Q"] = qs.mean().item()
v = self.model.V(zs[:-1])
value_info["V"] = v.mean().item()
# Losses
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(
-1, 1, 1
)
consistency_loss = (
rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask
).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
self.expectile = h.linear_schedule(self.cfg.expectile, step)
v_value_loss = (
rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask
).sum(dim=0)
total_loss = (
self.cfg.consistency_coef * consistency_loss
+ self.cfg.reward_coef * reward_loss
+ self.cfg.value_coef * q_value_loss
+ self.cfg.value_coef * v_value_loss
)
weighted_loss = (total_loss.squeeze(1) * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
weighted_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
self.optim.step()
if self.cfg.per:
# Update priorities
priorities = priority_loss.clamp(max=1e4).detach()
replay_buffer.update_priorities(
idxs[: replay_buffer.cfg.batch_size],
priorities[: replay_buffer.cfg.batch_size],
)
if self.demo_batch_size > 0:
demo_buffer.update_priorities(
demo_idxs, priorities[replay_buffer.cfg.batch_size :]
)
# Update policy + target network
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
if step % self.cfg.update_freq == 0:
h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau)
h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau)
self.model.eval()
metrics = {
"consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()),
"Q_value_loss": float(q_value_loss.mean().item()),
"V_value_loss": float(v_value_loss.mean().item()),
"total_loss": float(total_loss.mean().item()),
"weighted_loss": float(weighted_loss.mean().item()),
"grad_norm": float(grad_norm),
}
for key in ["demo_batch_size", "expectile"]:
if hasattr(self, key):
metrics[key] = getattr(self, key)
metrics.update(value_info)
metrics.update(pi_update_info)
return metrics
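Downstream, eval.py wraps this policy in a TensorDictModule so that env.rollout() can drive it directly (see the eval diff further below). A minimal sketch of that wiring, assuming a config object and checkpoint path are available:

from tensordict.nn import TensorDictModule

policy = TDMPC(cfg)              # cfg must provide the fields referenced above; a CUDA device is required
policy.load("path/to/final.pt")  # hypothetical checkpoint path
policy = TensorDictModule(
    policy,
    in_keys=["observation", "step_count"],  # fed into TDMPC.forward()
    out_keys=["action"],
)
rollout = env.rollout(max_steps=30, policy=policy)  # env built via make_env(cfg)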


@@ -0,0 +1,829 @@
import os
import pickle
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
__REDUCE__ = lambda b: "mean" if b else "none"
def l1(pred, target, reduce=False):
"""Computes the L1-loss between predictions and targets."""
return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))
def mse(pred, target, reduce=False):
"""Computes the MSE loss between predictions and targets."""
return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))
def l2_expectile(diff, expectile=0.7, reduce=False):
weight = torch.where(diff > 0, expectile, (1 - expectile))
loss = weight * (diff**2)
reduction = __REDUCE__(reduce)
if reduction == "mean":
return torch.mean(loss)
elif reduction == "sum":
return torch.sum(loss)
return loss
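# Example: with expectile=0.7, a residual diff = +2.0 contributes 0.7 * 4.0 = 2.8 to the loss,
# while diff = -2.0 contributes only 0.3 * 4.0 = 1.2, i.e. the V-network is penalized more for
# under-estimating its target than for over-estimating it (asymmetric expectile regression).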
def _get_out_shape(in_shape, layers):
"""Utility function. Returns the output shape of a network for a given input shape."""
x = torch.randn(*in_shape).unsqueeze(0)
return (
(nn.Sequential(*layers) if isinstance(layers, list) else layers)(x)
.squeeze(0)
.shape
)
def gaussian_logprob(eps, log_std):
"""Compute Gaussian log probability."""
residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)
def squash(mu, pi, log_pi):
"""Apply squashing function."""
mu = torch.tanh(mu)
pi = torch.tanh(pi)
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
return mu, pi, log_pi
def orthogonal_init(m):
"""Orthogonal layer initialization."""
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight.data)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
gain = nn.init.calculate_gain("relu")
nn.init.orthogonal_(m.weight.data, gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
def ema(m, m_target, tau):
"""Update slow-moving average of online network (target network) at rate tau."""
with torch.no_grad():
for p, p_target in zip(m.parameters(), m_target.parameters()):
p_target.data.lerp_(p.data, tau)
def set_requires_grad(net, value):
"""Enable/disable gradients for a given (sub)network."""
for param in net.parameters():
param.requires_grad_(value)
class TruncatedNormal(pyd.Normal):
"""Utility class implementing the truncated normal distribution."""
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
super().__init__(loc, scale, validate_args=False)
self.low = low
self.high = high
self.eps = eps
def _clamp(self, x):
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
x = x - x.detach() + clamped_x.detach()
return x
def sample(self, clip=None, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
eps *= self.scale
if clip is not None:
eps = torch.clamp(eps, -clip, clip)
x = self.loc + eps
return self._clamp(x)
class NormalizeImg(nn.Module):
"""Normalizes pixel observations to [0,1) range."""
def __init__(self):
super().__init__()
def forward(self, x):
return x.div(255.0)
class Flatten(nn.Module):
"""Flattens its input to a (batched) vector."""
def __init__(self):
super().__init__()
def forward(self, x):
return x.view(x.size(0), -1)
def enc(cfg):
"""Returns a TOLD encoder."""
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,),
}
pixels_enc_layers, state_enc_layers = None, None
if cfg.modality in {"pixels", "all"}:
C = int(3 * cfg.frame_stack)
pixels_enc_layers = [
NormalizeImg(),
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
nn.ReLU(),
nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
nn.ReLU(),
]
out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
pixels_enc_layers.extend(
[
Flatten(),
nn.Linear(np.prod(out_shape), cfg.latent_dim),
nn.LayerNorm(cfg.latent_dim),
nn.Sigmoid(),
]
)
if cfg.modality == "pixels":
return ConvExt(nn.Sequential(*pixels_enc_layers))
if cfg.modality in {"state", "all"}:
state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0]
state_enc_layers = [
nn.Linear(state_dim, cfg.enc_dim),
nn.ELU(),
nn.Linear(cfg.enc_dim, cfg.latent_dim),
nn.LayerNorm(cfg.latent_dim),
nn.Sigmoid(),
]
if cfg.modality == "state":
return nn.Sequential(*state_enc_layers)
else:
raise NotImplementedError
encoders = {}
for k in obs_shape:
if k == "state":
encoders[k] = nn.Sequential(*state_enc_layers)
elif k.endswith("rgb"):
encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
else:
raise NotImplementedError
return Multiplexer(nn.ModuleDict(encoders))
def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
"""Returns an MLP."""
if isinstance(mlp_dim, int):
mlp_dim = [mlp_dim, mlp_dim]
return nn.Sequential(
nn.Linear(in_dim, mlp_dim[0]),
nn.LayerNorm(mlp_dim[0]),
act_fn,
nn.Linear(mlp_dim[0], mlp_dim[1]),
nn.LayerNorm(mlp_dim[1]),
act_fn,
nn.Linear(mlp_dim[1], out_dim),
)
def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
"""Returns a dynamics network."""
return nn.Sequential(
mlp(in_dim, mlp_dim, out_dim, act_fn),
nn.LayerNorm(out_dim),
nn.Sigmoid(),
)
def q(cfg):
"""Returns a Q-function that uses Layer Normalization."""
action_dim = 4
return nn.Sequential(
nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim),
nn.LayerNorm(cfg.mlp_dim),
nn.Tanh(),
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
nn.ELU(),
nn.Linear(cfg.mlp_dim, 1),
)
def v(cfg):
"""Returns a state value function that uses Layer Normalization."""
return nn.Sequential(
nn.Linear(cfg.latent_dim, cfg.mlp_dim),
nn.LayerNorm(cfg.mlp_dim),
nn.Tanh(),
nn.Linear(cfg.mlp_dim, cfg.mlp_dim),
nn.ELU(),
nn.Linear(cfg.mlp_dim, 1),
)
def aug(cfg):
"""Multiplex augmentation."""
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (4,),
}
if cfg.modality == "state":
return nn.Identity()
elif cfg.modality == "pixels":
return RandomShiftsAug(cfg)
else:
augs = {}
for k in obs_shape:
if k == "state":
augs[k] = nn.Identity()
elif k.endswith("rgb"):
augs[k] = RandomShiftsAug(cfg)
else:
raise NotImplementedError
return Multiplexer(nn.ModuleDict(augs))
class ConvExt(nn.Module):
"""Auxiliary conv net accommodating high-dim input"""
def __init__(self, conv):
super().__init__()
self.conv = conv
def forward(self, x):
if x.ndim > 4:
batch_shape = x.shape[:-3]
out = self.conv(x.view(-1, *x.shape[-3:]))
out = out.view(*batch_shape, *out.shape[1:])
else:
out = self.conv(x)
return out
class Multiplexer(nn.Module):
"""Model multiplexer"""
def __init__(self, choices):
super().__init__()
self.choices = choices
def forward(self, x, key=None):
if isinstance(x, dict):
if key is not None:
return self.choices[key](x)
return {k: self.choices[k](_x) for k, _x in x.items()}
return self.choices(x)
class RandomShiftsAug(nn.Module):
"""
Random shift image augmentation.
Adapted from https://github.com/facebookresearch/drqv2
"""
def __init__(self, cfg):
super().__init__()
assert cfg.modality in {"pixels", "all"}
self.pad = int(cfg.img_size / 21)
def forward(self, x):
n, c, h, w = x.size()
assert h == w
padding = tuple([self.pad] * 4)
x = F.pad(x, padding, "replicate")
eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(
-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
shift = torch.randint(
0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
)
shift *= 2.0 / (h + 2 * self.pad)
grid = base_grid + shift
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
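# With self.pad = img_size / 21 (the DrQ-v2 ratio), an 84x84 input is replicate-padded by 4 pixels
# and resampled with grid_sample, i.e. each image in the batch is randomly shifted by up to +/-4 pixels.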
class Episode(object):
"""Storage object for a single episode."""
def __init__(self, cfg, init_obs):
action_dim = 4
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)
if cfg.modality in {"pixels", "state"}:
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
self.obses = torch.empty(
(cfg.episode_length + 1, *init_obs.shape),
dtype=dtype,
device=self.device,
)
self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
elif cfg.modality == "all":
self.obses = {}
for k, v in init_obs.items():
assert k in {"rgb", "state"}
dtype = torch.float32 if k == "state" else torch.uint8
self.obses[k] = torch.empty(
(cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
)
self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
else:
raise ValueError
self.actions = torch.empty(
(cfg.episode_length, action_dim), dtype=torch.float32, device=self.device
)
self.rewards = torch.empty(
(cfg.episode_length,), dtype=torch.float32, device=self.device
)
self.dones = torch.empty(
(cfg.episode_length,), dtype=torch.bool, device=self.device
)
self.masks = torch.empty(
(cfg.episode_length,), dtype=torch.float32, device=self.device
)
self.cumulative_reward = 0
self.done = False
self.success = False
self._idx = 0
def __len__(self):
return self._idx
@classmethod
def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
"""Constructs an episode from a trajectory."""
if cfg.modality in {"pixels", "state"}:
episode = cls(cfg, obses[0])
episode.obses[1:] = torch.tensor(
obses[1:], dtype=episode.obses.dtype, device=episode.device
)
elif cfg.modality == "all":
episode = cls(cfg, {k: v[0] for k, v in obses.items()})
for k, v in obses.items():
episode.obses[k][1:] = torch.tensor(
obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
)
else:
raise NotImplementedError
episode.actions = torch.tensor(
actions, dtype=episode.actions.dtype, device=episode.device
)
episode.rewards = torch.tensor(
rewards, dtype=episode.rewards.dtype, device=episode.device
)
episode.dones = (
torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
if dones is not None
else torch.zeros_like(episode.dones)
)
episode.masks = (
torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
if masks is not None
else torch.ones_like(episode.masks)
)
episode.cumulative_reward = torch.sum(episode.rewards)
episode.done = True
episode._idx = cfg.episode_length
return episode
@property
def first(self):
return len(self) == 0
def __add__(self, transition):
self.add(*transition)
return self
def add(self, obs, action, reward, done, mask=1.0, success=False):
"""Add a transition into the episode."""
if isinstance(obs, dict):
for k, v in obs.items():
self.obses[k][self._idx + 1] = torch.tensor(
v, dtype=self.obses[k].dtype, device=self.obses[k].device
)
else:
self.obses[self._idx + 1] = torch.tensor(
obs, dtype=self.obses.dtype, device=self.obses.device
)
self.actions[self._idx] = action
self.rewards[self._idx] = reward
self.dones[self._idx] = done
self.masks[self._idx] = mask
self.cumulative_reward += reward
self.done = done
self.success = self.success or success
self._idx += 1
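# Usage sketch (illustrative): an Episode is seeded with the first observation and then grown one
# transition at a time via `+=`, which forwards to add():
#   episode = Episode(cfg, init_obs)
#   episode += (obs, action, reward, done, mask, success)
#   if episode.done:
#       replay_buffer += episode  # ReplayBuffer.__add__ below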
class ReplayBuffer:
"""
Storage and sampling functionality.
"""
def __init__(self, cfg, dataset=None):
action_dim = 4
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)}
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)
print("Replay buffer device: ", self.device)
if dataset is not None:
self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size)
else:
self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
if cfg.modality in {"pixels", "state"}:
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
# Note self.obs_shape always has single frame, which is different from cfg.obs_shape
self.obs_shape = (
obs_shape if cfg.modality == "state" else (3, *obs_shape[-2:])
)
self._obs = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape),
dtype=dtype,
device=self.device,
)
self._next_obs = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape),
dtype=dtype,
device=self.device,
)
elif cfg.modality == "all":
self.obs_shape = {}
self._obs, self._next_obs = {}, {}
for k, v in obs_shape.items():
assert k in {"rgb", "state"}
dtype = torch.float32 if k == "state" else torch.uint8
self.obs_shape[k] = v if k == "state" else (3, *v[-2:])
self._obs[k] = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape[k]),
dtype=dtype,
device=self.device,
)
self._next_obs[k] = self._obs[k].clone()
else:
raise ValueError
self._action = torch.zeros(
(self.capacity + cfg.horizon - 1, action_dim),
dtype=torch.float32,
device=self.device,
)
self._reward = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._mask = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._done = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.bool, device=self.device
)
self._priorities = torch.ones(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._eps = 1e-6
self._full = False
self.idx = 0
if dataset is not None:
self.init_from_offline_dataset(dataset)
self._aug = aug(cfg)
def init_from_offline_dataset(self, dataset):
"""Initialize the replay buffer from an offline dataset."""
assert self.idx == 0 and not self._full
n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)
def copy_data(dst, src, n):
assert isinstance(dst, dict) == isinstance(src, dict)
if isinstance(dst, dict):
for k in dst:
copy_data(dst[k], src[k], n)
else:
dst[:n] = torch.from_numpy(src[:n])
copy_data(self._obs, dataset["observations"], n_transitions)
copy_data(self._next_obs, dataset["next_observations"], n_transitions)
copy_data(self._action, dataset["actions"], n_transitions)
copy_data(self._reward, dataset["rewards"], n_transitions)
copy_data(self._mask, dataset["masks"], n_transitions)
copy_data(self._done, dataset["dones"], n_transitions)
self.idx = (self.idx + n_transitions) % self.capacity
self._full = n_transitions >= self.capacity
def __add__(self, episode: Episode):
self.add(episode)
return self
def add(self, episode: Episode):
"""Add an episode to the replay buffer."""
if self.idx + len(episode) > self.capacity:
print("Warning: episode got truncated")
ep_len = min(len(episode), self.capacity - self.idx)
idxs = slice(self.idx, self.idx + ep_len)
assert self.idx + ep_len <= self.capacity
if self.cfg.modality in {"pixels", "state"}:
self._obs[idxs] = (
episode.obses[:ep_len]
if self.cfg.modality == "state"
else episode.obses[:ep_len, -3:]
)
self._next_obs[idxs] = (
episode.obses[1 : ep_len + 1]
if self.cfg.modality == "state"
else episode.obses[1 : ep_len + 1, -3:]
)
elif self.cfg.modality == "all":
for k, v in episode.obses.items():
assert k in {"rgb", "state"}
assert k in self._obs
assert k in self._next_obs
if k == "rgb":
self._obs[k][idxs] = episode.obses[k][:ep_len, -3:]
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:]
else:
self._obs[k][idxs] = episode.obses[k][:ep_len]
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1]
self._action[idxs] = episode.actions[:ep_len]
self._reward[idxs] = episode.rewards[:ep_len]
self._mask[idxs] = episode.masks[:ep_len]
self._done[idxs] = episode.dones[:ep_len]
self._done[self.idx + ep_len - 1] = True # in case truncated
if self._full:
max_priority = (
self._priorities[: self.capacity].max().to(self.device).item()
)
else:
max_priority = (
1.0
if self.idx == 0
else self._priorities[: self.idx].max().to(self.device).item()
)
new_priorities = torch.full((ep_len,), max_priority, device=self.device)
self._priorities[idxs] = new_priorities
self.idx = (self.idx + ep_len) % self.capacity
self._full = self._full or self.idx == 0
def update_priorities(self, idxs, priorities):
"""Update priorities for Prioritized Experience Replay (PER)"""
self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps
def _get_obs(self, arr, idxs):
"""Retrieve observations by indices"""
if isinstance(arr, dict):
return {k: self._get_obs(v, idxs) for k, v in arr.items()}
if arr.ndim <= 2: # if self.cfg.modality == 'state':
return arr[idxs].cuda()
obs = torch.empty(
(self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]),
dtype=arr.dtype,
device=torch.device("cuda"),
)
obs[:, -3:] = arr[idxs].cuda()
_idxs = idxs.clone()
mask = torch.ones_like(_idxs, dtype=torch.bool)
for i in range(1, self.cfg.frame_stack):
mask[_idxs % self.cfg.episode_length == 0] = False
_idxs[mask] -= 1
obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda()
return obs.float()
def sample(self):
"""Sample transitions from the replay buffer."""
probs = (
self._priorities[: self.capacity]
if self._full
else self._priorities[: self.idx]
) ** self.cfg.per_alpha
probs /= probs.sum()
total = len(probs)
idxs = torch.from_numpy(
np.random.choice(
total,
self.cfg.batch_size,
p=probs.cpu().numpy(),
replace=not self._full,
)
).to(self.device)
weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
weights /= weights.max()
idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)])
obs = self._aug(self._get_obs(self._obs, idxs))
next_obs = [
self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon
]
if isinstance(next_obs[0], dict):
next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
else:
next_obs = torch.stack(next_obs)
action = self._action[idxs_in_horizon]
reward = self._reward[idxs_in_horizon]
mask = self._mask[idxs_in_horizon]
done = self._done[idxs_in_horizon]
if not action.is_cuda:
action, reward, mask, done, idxs, weights = (
action.cuda(),
reward.cuda(),
mask.cuda(),
done.cuda(),
idxs.cuda(),
weights.cuda(),
)
return (
obs,
next_obs,
action,
reward.unsqueeze(2),
mask.unsqueeze(2),
done.unsqueeze(2),
idxs,
weights,
)
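# Shape notes (B = cfg.batch_size, H = cfg.horizon): obs is a single-step batch (a dict of
# per-modality tensors when modality == "all"), next_obs is stacked with the horizon as the
# leading dimension, action comes out as (H, B, 4), and reward/mask/done are unsqueezed to
# (H, B, 1) so they broadcast against the per-step losses in TDMPC.update().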
def save(self, path):
"""Save the replay buffer to path"""
print(f"saving replay buffer to '{path}'...")
sz = self.capacity if self._full else self.idx
dataset = {
"observations": (
{k: v[:sz].cpu().numpy() for k, v in self._obs.items()}
if isinstance(self._obs, dict)
else self._obs[:sz].cpu().numpy()
),
"next_observations": (
{k: v[:sz].cpu().numpy() for k, v in self._next_obs.items()}
if isinstance(self._next_obs, dict)
else self._next_obs[:sz].cpu().numpy()
),
"actions": self._action[:sz].cpu().numpy(),
"rewards": self._reward[:sz].cpu().numpy(),
"dones": self._done[:sz].cpu().numpy(),
"masks": self._mask[:sz].cpu().numpy(),
}
with open(path, "wb") as f:
pickle.dump(dataset, f)
return dataset
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
"""Construct a dataset for env"""
required_keys = [
"observations",
"next_observations",
"actions",
"rewards",
"dones",
"masks",
]
if cfg.task.startswith("xarm"):
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
for k in required_keys:
if k not in dataset_dict and k[:-1] in dataset_dict:
dataset_dict[k] = dataset_dict.pop(k[:-1])
elif cfg.task.startswith("legged"):
dataset_path = os.path.join(cfg.dataset_dir, "buffer.pkl")
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
dataset_dict["actions"] /= env.unwrapped.clip_actions
print(f"clip_actions={env.unwrapped.clip_actions}")
else:
import d4rl
dataset_dict = d4rl.qlearning_dataset(env)
dones = np.full_like(dataset_dict["rewards"], False, dtype=bool)
for i in range(len(dones) - 1):
if (
np.linalg.norm(
dataset_dict["observations"][i + 1]
- dataset_dict["next_observations"][i]
)
> 1e-6
or dataset_dict["terminals"][i] == 1.0
):
dones[i] = True
dones[-1] = True
dataset_dict["masks"] = 1.0 - dataset_dict["terminals"]
del dataset_dict["terminals"]
for k, v in dataset_dict.items():
dataset_dict[k] = v.astype(np.float32)
dataset_dict["dones"] = dones
if cfg.is_data_clip:
lim = 1 - cfg.data_clip_eps
dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)
reward_normalizer = get_reward_normalizer(cfg, dataset_dict)
dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"])
for key in required_keys:
assert key in dataset_dict.keys(), f"Missing `{key}` in dataset."
if return_reward_normalizer:
return dataset_dict, reward_normalizer
return dataset_dict
def get_trajectory_boundaries_and_returns(dataset):
"""
Split dataset into trajectories and compute returns
"""
episode_starts = [0]
episode_ends = []
episode_return = 0
episode_returns = []
n_transitions = len(dataset["rewards"])
for i in range(n_transitions):
episode_return += dataset["rewards"][i]
if dataset["dones"][i]:
episode_returns.append(episode_return)
episode_ends.append(i + 1)
if i + 1 < n_transitions:
episode_starts.append(i + 1)
episode_return = 0.0
return episode_starts, episode_ends, episode_returns
def normalize_returns(dataset, scaling=1000):
"""
Normalize returns in the dataset
"""
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
dataset["rewards"] *= scaling
return dataset
def get_reward_normalizer(cfg, dataset):
"""
Get a reward normalizer for the dataset
"""
if cfg.task.startswith("xarm"):
return lambda x: x
elif "maze" in cfg.task:
return lambda x: x - 1.0
elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]:
(_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset)
return (
lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0
)
elif hasattr(cfg, "reward_scale"):
return lambda x: x * cfg.reward_scale
return lambda x: x
def linear_schedule(schdl, step):
"""
Outputs values following a linear decay schedule.
Adapted from https://github.com/facebookresearch/drqv2
"""
try:
return float(schdl)
except ValueError:
match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl)
if match:
init, final, start, end = [float(g) for g in match.groups()]
mix = np.clip((step - start) / (end - start), 0.0, 1.0)
return (1.0 - mix) * init + mix * final
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
if match:
init, final, duration = [float(g) for g in match.groups()]
mix = np.clip(step / duration, 0.0, 1.0)
return (1.0 - mix) * init + mix * final
raise NotImplementedError(schdl)
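For reference, a few concrete evaluations of the schedule strings parsed above (the numbers are illustrative, not defaults from this commit):

linear_schedule("0.1", 0)                             # plain float string -> 0.1 at every step
linear_schedule("linear(0.5,0.05,25000)", 0)          # -> 0.5
linear_schedule("linear(0.5,0.05,25000)", 12500)      # -> 0.275 (halfway between 0.5 and 0.05)
linear_schedule("linear(0.5,0.05,25000)", 50000)      # -> 0.05 (clipped once past the duration)
linear_schedule("linear(1.0,0.1,5000,15000)", 10000)  # 4-arg form ramps between steps 5000 and 15000 -> 0.55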

lerobot/common/utils.py (new file, 12 lines)

@@ -0,0 +1,12 @@
import random
import numpy as np
import torch
def set_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


@@ -5,70 +5,54 @@ import imageio
 import numpy as np
 import torch
 from tensordict import TensorDict
+from tensordict.nn import TensorDictModule
 from termcolor import colored
 
-from lerobot.lib.envs.factory import make_env
-from lerobot.lib.tdmpc import TDMPC
-from lerobot.lib.utils import set_seed
+from lerobot.common.envs.factory import make_env
+from lerobot.common.tdmpc import TDMPC
+from lerobot.common.utils import set_seed
 
 
-def eval_agent(
-    env, agent, num_episodes: int, save_video: bool = False, video_path: Path = None
+def eval_policy(
+    env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None
 ):
-    """Evaluate a trained agent and optionally save a video."""
-    if save_video:
-        assert video_path is not None
-        assert video_path.suffix == ".mp4"
-    episode_rewards = []
-    episode_successes = []
-    episode_lengths = []
+    rewards = []
+    successes = []
     for i in range(num_episodes):
-        td = env.reset()
-        obs = {}
-        obs["rgb"] = td["observation"]["camera"]
-        obs["state"] = td["observation"]["robot_state"]
-        done = False
-        ep_reward = 0
-        t = 0
-        ep_success = False
-        if save_video:
-            frames = []
-        while not done:
-            action = agent.act(obs, t0=t == 0, eval_mode=True, step=100000)
-            td = TensorDict({"action": action}, batch_size=[])
-            td = env.step(td)
-            reward = td["next", "reward"].item()
-            success = td["next", "success"].item()
-            done = td["next", "done"].item()
-            obs = {}
-            obs["rgb"] = td["next", "observation"]["camera"]
-            obs["state"] = td["next", "observation"]["robot_state"]
-            ep_reward += reward
-            if success:
-                ep_success = True
-            if save_video:
-                frame = env.render()
-                frames.append(frame)
-            t += 1
-        episode_rewards.append(float(ep_reward))
-        episode_successes.append(float(ep_success))
-        episode_lengths.append(t)
-    if save_video:
-        video_path.parent.mkdir(parents=True, exist_ok=True)
-        frames = np.stack(frames)  # .transpose(0, 3, 1, 2)
-        # TODO(rcadene): make fps configurable
-        imageio.mimsave(video_path, frames, fps=15)
-    return {
-        "episode_reward": np.nanmean(episode_rewards),
-        "episode_success": np.nanmean(episode_successes),
-        "episode_length": np.nanmean(episode_lengths),
-    }
+        ep_frames = []
+
+        def rendering_callback(env, td=None):
+            nonlocal ep_frames
+            frame = env.render()
+            ep_frames.append(frame)
+
+        tensordict = env.reset()
+        # render first frame before rollout
+        rendering_callback(env)
+
+        rollout = env.rollout(
+            max_steps=30,
+            policy=policy,
+            callback=rendering_callback,
+            auto_reset=False,
+            tensordict=tensordict,
+        )
+        ep_reward = rollout["next", "reward"].sum()
+        ep_success = rollout["next", "success"].any()
+        rewards.append(ep_reward.item())
+        successes.append(ep_success.item())
+
+        if save_video:
+            video_dir.parent.mkdir(parents=True, exist_ok=True)
+            # TODO(rcadene): make fps configurable
+            video_path = video_dir / f"eval_episode_{i}.mp4"
+            imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
+
+    metrics = {
+        "avg_reward": np.nanmean(rewards),
+        "pc_success": np.nanmean(successes) * 100,
+    }
+    return metrics
 
 
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -78,20 +62,25 @@ def eval(cfg: dict):
     print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
 
     env = make_env(cfg)
-    agent = TDMPC(cfg)
+    policy = TDMPC(cfg)
     # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
     ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-    agent.load(ckpt_path)
+    policy.load(ckpt_path)
 
-    eval_metrics = eval_agent(
-        env,
-        agent,
-        num_episodes=10,
-        save_video=True,
-        video_path=Path("tmp/2023_01_29_xarm_lift_final/eval.mp4"),
+    policy = TensorDictModule(
+        policy,
+        in_keys=["observation", "step_count"],
+        out_keys=["action"],
     )
-    print(eval_metrics)
+
+    metrics = eval_policy(
+        env,
+        policy,
+        num_episodes=10,
+        save_video=True,
+        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
+    )
+    print(metrics)
 
 
 if __name__ == "__main__":


@@ -2,7 +2,10 @@ import hydra
 import torch
 from termcolor import colored
 
-from ..lib.utils import set_seed
+from lerobot.common.envs.factory import make_env
+from lerobot.common.tdmpc import TDMPC
+from ..common.utils import set_seed
 
 
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -11,6 +14,24 @@ def train(cfg: dict):
     set_seed(cfg.seed)
     print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)
 
+    env = make_env(cfg)
+    agent = TDMPC(cfg)
+    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    agent.load(ckpt_path)
+
+    # online training
+    eval_metrics = train_agent(
+        env,
+        agent,
+        num_episodes=10,
+        save_video=True,
+        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
+    )
+    print(eval_metrics)
+
+
 if __name__ == "__main__":
     train()


@@ -2,7 +2,7 @@ import pytest
 from tensordict import TensorDict
 from torchrl.envs.utils import check_env_specs, step_mdp
 
-from lerobot.lib.envs import SimxarmEnv
+from lerobot.common.envs import SimxarmEnv
 
 
 @pytest.mark.parametrize(