From 5a5b190f70b2782f137766bf79d90bc720dba1c7 Mon Sep 17 00:00:00 2001 From: Cadene Date: Wed, 31 Jan 2024 13:48:12 +0000 Subject: [PATCH] Add common, refactor eval with eval_policy --- lerobot/common/__init__.py | 0 lerobot/common/envs/__init__.py | 0 lerobot/common/envs/factory.py | 42 ++ lerobot/common/envs/simxarm.py | 183 +++++++ lerobot/common/tdmpc.py | 450 +++++++++++++++++ lerobot/common/tdmpc_helper.py | 829 ++++++++++++++++++++++++++++++++ lerobot/common/utils.py | 12 + lerobot/scripts/eval.py | 113 ++--- lerobot/scripts/train.py | 23 +- test/test_envs.py | 2 +- 10 files changed, 1590 insertions(+), 64 deletions(-) create mode 100644 lerobot/common/__init__.py create mode 100644 lerobot/common/envs/__init__.py create mode 100644 lerobot/common/envs/factory.py create mode 100644 lerobot/common/envs/simxarm.py create mode 100644 lerobot/common/tdmpc.py create mode 100644 lerobot/common/tdmpc_helper.py create mode 100644 lerobot/common/utils.py diff --git a/lerobot/common/__init__.py b/lerobot/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lerobot/common/envs/__init__.py b/lerobot/common/envs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py new file mode 100644 index 00000000..de42bc26 --- /dev/null +++ b/lerobot/common/envs/factory.py @@ -0,0 +1,42 @@ +from torchrl.envs.transforms import StepCounter, TransformedEnv + +from lerobot.common.envs.simxarm import SimxarmEnv + + +def make_env(cfg): + assert cfg.env == "simxarm" + env = SimxarmEnv( + task=cfg.task, + from_pixels=cfg.from_pixels, + pixels_only=cfg.pixels_only, + image_size=cfg.image_size, + ) + + # limit rollout to max_steps + env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length)) + + return env + + +# def make_env(env_name, frame_skip, device, is_test=False): +# env = GymEnv( +# env_name, +# frame_skip=frame_skip, +# from_pixels=True, +# pixels_only=False, +# device=device, +# ) +# env = TransformedEnv(env) +# env.append_transform(NoopResetEnv(noops=30, random=True)) +# if not is_test: +# env.append_transform(EndOfLifeTransform()) +# env.append_transform(RewardClipping(-1, 1)) +# env.append_transform(ToTensorImage()) +# env.append_transform(GrayScale()) +# env.append_transform(Resize(84, 84)) +# env.append_transform(CatFrames(N=4, dim=-3)) +# env.append_transform(RewardSum()) +# env.append_transform(StepCounter(max_steps=4500)) +# env.append_transform(DoubleToFloat()) +# env.append_transform(VecNorm(in_keys=["pixels"])) +# return env diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py new file mode 100644 index 00000000..3ca3ae0f --- /dev/null +++ b/lerobot/common/envs/simxarm.py @@ -0,0 +1,183 @@ +import importlib +from typing import Optional + +import numpy as np +import torch +from tensordict import TensorDict +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.envs import EnvBase +from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform + +from lerobot.common.utils import set_seed + +_has_gym = importlib.util.find_spec("gym") is not None +_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym + + +class SimxarmEnv(EnvBase): + + def __init__( + self, + task, + from_pixels: bool = False, + pixels_only: bool = False, + image_size=None, + seed=1337, + device="cpu", + ): + super().__init__(device=device, batch_size=[]) + self.task = task + 
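+        # Observation configuration (descriptive note): `from_pixels` switches to rendered camera frames,
+        # `pixels_only` drops the proprioceptive robot state, and `image_size` sets the render resolution.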
self.from_pixels = from_pixels + self.pixels_only = pixels_only + self.image_size = image_size + + if pixels_only: + assert from_pixels + if from_pixels: + assert image_size + + if not _has_simxarm: + raise ImportError("Cannot import simxarm.") + if not _has_gym: + raise ImportError("Cannot import gym.") + + import gym + from gym.wrappers import TimeLimit + from simxarm import TASKS + + if self.task not in TASKS: + raise ValueError( + f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}" + ) + + self._env = TASKS[self.task]["env"]() + self._env = TimeLimit(self._env, TASKS[self.task]["episode_length"]) + + MAX_NUM_ACTIONS = 4 + num_actions = len(TASKS[self.task]["action_space"]) + self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) + self._action_padding = np.zeros( + (MAX_NUM_ACTIONS - num_actions), dtype=np.float32 + ) + if "w" not in TASKS[self.task]["action_space"]: + self._action_padding[-1] = 1.0 + + self._make_spec() + self.set_seed(seed) + + def render(self, mode="rgb_array", width=384, height=384): + return self._env.render(mode, width=width, height=height) + + def _format_raw_obs(self, raw_obs): + if self.from_pixels: + camera = self.render( + mode="rgb_array", width=self.image_size, height=self.image_size + ) + camera = camera.transpose(2, 0, 1) # (H, W, C) -> (C, H, W) + camera = torch.tensor(camera.copy(), dtype=torch.uint8) + + obs = {"camera": camera} + + if not self.pixels_only: + obs["robot_state"] = torch.tensor( + self._env.robot_state, dtype=torch.float32 + ) + else: + obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)} + + obs = TensorDict(obs, batch_size=[]) + return obs + + def _reset(self, tensordict: Optional[TensorDict] = None): + td = tensordict + if td is None or td.is_empty(): + raw_obs = self._env.reset() + + td = TensorDict( + { + "observation": self._format_raw_obs(raw_obs), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) + else: + raise NotImplementedError() + return td + + def _step(self, tensordict: TensorDict): + td = tensordict + action = td["action"].numpy() + # step expects shape=(4,) so we pad if necessary + action = np.concatenate([action, self._action_padding]) + # TODO(rcadene): add info["is_success"] and info["success"] ? + raw_obs, reward, done, info = self._env.step(action) + + td = TensorDict( + { + "observation": self._format_raw_obs(raw_obs), + "reward": torch.tensor([reward], dtype=torch.float32), + "done": torch.tensor([done], dtype=torch.bool), + "success": torch.tensor([info["success"]], dtype=torch.bool), + }, + batch_size=[], + ) + return td + + def _make_spec(self): + obs = {} + if self.from_pixels: + obs["camera"] = BoundedTensorSpec( + low=0, + high=255, + shape=(3, self.image_size, self.image_size), + dtype=torch.uint8, + device=self.device, + ) + if not self.pixels_only: + obs["robot_state"] = UnboundedContinuousTensorSpec( + shape=(len(self._env.robot_state),), + dtype=torch.float32, + device=self.device, + ) + else: + # TODO(rcadene): add observation_space achieved_goal and desired_goal? 
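+            # State-only modality: the spec mirrors the shape of the underlying gym observation_space["observation"].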
+ obs["state"] = UnboundedContinuousTensorSpec( + shape=self._env.observation_space["observation"].shape, + dtype=torch.float32, + device=self.device, + ) + self.observation_spec = CompositeSpec({"observation": obs}) + + self.action_spec = _gym_to_torchrl_spec_transform( + self._action_space, + device=self.device, + ) + + self.reward_spec = UnboundedContinuousTensorSpec( + shape=(1,), + dtype=torch.float32, + device=self.device, + ) + + self.done_spec = DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ) + + self.success_spec = DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.bool, + device=self.device, + ) + + def _set_seed(self, seed: Optional[int]): + set_seed(seed) + self._env.seed(seed) diff --git a/lerobot/common/tdmpc.py b/lerobot/common/tdmpc.py new file mode 100644 index 00000000..df4e647b --- /dev/null +++ b/lerobot/common/tdmpc.py @@ -0,0 +1,450 @@ +from copy import deepcopy + +import numpy as np +import torch +import torch.nn as nn + +import lerobot.common.tdmpc_helper as h + + +class TOLD(nn.Module): + """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" + + def __init__(self, cfg): + super().__init__() + action_dim = 4 + + self.cfg = cfg + self._encoder = h.enc(cfg) + self._dynamics = h.dynamics( + cfg.latent_dim + action_dim, cfg.mlp_dim, cfg.latent_dim + ) + self._reward = h.mlp(cfg.latent_dim + action_dim, cfg.mlp_dim, 1) + self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, action_dim) + self._Qs = nn.ModuleList([h.q(cfg) for _ in range(cfg.num_q)]) + self._V = h.v(cfg) + self.apply(h.orthogonal_init) + for m in [self._reward, *self._Qs]: + m[-1].weight.data.fill_(0) + m[-1].bias.data.fill_(0) + + def track_q_grad(self, enable=True): + """Utility function. Enables/disables gradient tracking of Q-networks.""" + for m in self._Qs: + h.set_requires_grad(m, enable) + + def track_v_grad(self, enable=True): + """Utility function. 
Enables/disables gradient tracking of Q-networks.""" + if hasattr(self, "_V"): + h.set_requires_grad(self._V, enable) + + def encode(self, obs): + """Encodes an observation into its latent representation.""" + out = self._encoder(obs) + if isinstance(obs, dict): + # fusion + out = torch.stack([v for k, v in out.items()]).mean(dim=0) + return out + + def next(self, z, a): + """Predicts next latent state (d) and single-step reward (R).""" + x = torch.cat([z, a], dim=-1) + return self._dynamics(x), self._reward(x) + + def pi(self, z, std=0): + """Samples an action from the learned policy (pi).""" + mu = torch.tanh(self._pi(z)) + if std > 0: + std = torch.ones_like(mu) * std + return h.TruncatedNormal(mu, std).sample(clip=0.3) + return mu + + def V(self, z): + """Predict state value (V).""" + return self._V(z) + + def Q(self, z, a, return_type): + """Predict state-action value (Q).""" + assert return_type in {"min", "avg", "all"} + x = torch.cat([z, a], dim=-1) + + if return_type == "all": + return torch.stack(list(q(x) for q in self._Qs), dim=0) + + idxs = np.random.choice(self.cfg.num_q, 2, replace=False) + Q1, Q2 = self._Qs[idxs[0]](x), self._Qs[idxs[1]](x) + return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2 + + +class TDMPC(nn.Module): + """Implementation of TD-MPC learning + inference.""" + + def __init__(self, cfg): + super().__init__() + self.action_dim = 4 + + self.cfg = cfg + self.device = torch.device("cuda") + self.std = h.linear_schedule(cfg.std_schedule, 0) + self.model = TOLD(cfg).cuda() + self.model_target = deepcopy(self.model) + self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) + self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr) + self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) + self.model.eval() + self.model_target.eval() + self.batch_size = cfg.batch_size + + # TODO(rcadene): clean + self.step = 100000 + + def state_dict(self): + """Retrieve state dict of TOLD model, including slow-moving target network.""" + return { + "model": self.model.state_dict(), + "model_target": self.model_target.state_dict(), + } + + def save(self, fp): + """Save state dict of TOLD model to filepath.""" + torch.save(self.state_dict(), fp) + + def load(self, fp): + """Load a saved state dict from filepath into current agent.""" + d = torch.load(fp) + self.model.load_state_dict(d["model"]) + self.model_target.load_state_dict(d["model_target"]) + + @torch.no_grad() + def forward(self, observation, step_count): + t0 = step_count.item() == 0 + obs = { + "rgb": observation["camera"], + "state": observation["robot_state"], + } + return self.act(obs, t0=t0, step=self.step) + + @torch.no_grad() + def act(self, obs, t0=False, step=None): + """Take an action. 
Uses either MPC or the learned policy, depending on the self.cfg.mpc flag.""" + if isinstance(obs, dict): + obs = { + k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0) + for k, o in obs.items() + } + else: + obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze( + 0 + ) + z = self.model.encode(obs) + if self.cfg.mpc: + a = self.plan(z, t0=t0, step=step) + else: + a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0) + return a.cpu() + + @torch.no_grad() + def estimate_value(self, z, actions, horizon): + """Estimate value of a trajectory starting at latent state z and executing given actions.""" + G, discount = 0, 1 + for t in range(horizon): + if self.cfg.uncertainty_cost > 0: + G -= ( + discount + * self.cfg.uncertainty_cost + * self.model.Q(z, actions[t], return_type="all").std(dim=0) + ) + z, reward = self.model.next(z, actions[t]) + G += discount * reward + discount *= self.cfg.discount + pi = self.model.pi(z, self.cfg.min_std) + G += discount * self.model.Q(z, pi, return_type="min") + if self.cfg.uncertainty_cost > 0: + G -= ( + discount + * self.cfg.uncertainty_cost + * self.model.Q(z, pi, return_type="all").std(dim=0) + ) + return G + + @torch.no_grad() + def plan(self, z, step=None, t0=True): + """ + Plan next action using TD-MPC inference. + z: latent state. + step: current time step. determines e.g. planning horizon. + t0: whether current step is the first step of an episode. + """ + # during eval: eval_mode: uniform sampling and action noise is disabled during evaluation. + + assert step is not None + # Seed steps + if step < self.cfg.seed_steps and self.model.training: + return torch.empty( + self.action_dim, dtype=torch.float32, device=self.device + ).uniform_(-1, 1) + + # Sample policy trajectories + horizon = int( + min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)) + ) + num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) + if num_pi_trajs > 0: + pi_actions = torch.empty( + horizon, num_pi_trajs, self.action_dim, device=self.device + ) + _z = z.repeat(num_pi_trajs, 1) + for t in range(horizon): + pi_actions[t] = self.model.pi(_z, self.cfg.min_std) + _z, _ = self.model.next(_z, pi_actions[t]) + + # Initialize state and parameters + z = z.repeat(self.cfg.num_samples + num_pi_trajs, 1) + mean = torch.zeros(horizon, self.action_dim, device=self.device) + std = self.cfg.max_std * torch.ones( + horizon, self.action_dim, device=self.device + ) + if not t0 and hasattr(self, "_prev_mean"): + mean[:-1] = self._prev_mean[1:] + + # Iterate CEM + for i in range(self.cfg.iterations): + actions = torch.clamp( + mean.unsqueeze(1) + + std.unsqueeze(1) + * torch.randn( + horizon, self.cfg.num_samples, self.action_dim, device=std.device + ), + -1, + 1, + ) + if num_pi_trajs > 0: + actions = torch.cat([actions, pi_actions], dim=1) + + # Compute elite actions + value = self.estimate_value(z, actions, horizon).nan_to_num_(0) + elite_idxs = torch.topk( + value.squeeze(1), self.cfg.num_elites, dim=0 + ).indices + elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] + + # Update parameters + max_value = elite_value.max(0)[0] + score = torch.exp(self.cfg.temperature * (elite_value - max_value)) + score /= score.sum(0) + _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / ( + score.sum(0) + 1e-9 + ) + _std = torch.sqrt( + torch.sum( + score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2, + dim=1, + ) + / (score.sum(0) + 1e-9) + ) + _std = _std.clamp_(self.std, 
self.cfg.max_std) + mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std + + # Outputs + score = score.squeeze(1).cpu().numpy() + actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] + self._prev_mean = mean + mean, std = actions[0], _std[0] + a = mean + if self.model.training: + a += std * torch.randn(self.action_dim, device=std.device) + return torch.clamp(a, -1, 1) + + def update_pi(self, zs, acts=None): + """Update policy using a sequence of latent states.""" + self.pi_optim.zero_grad(set_to_none=True) + self.model.track_q_grad(False) + self.model.track_v_grad(False) + + info = {} + # Advantage Weighted Regression + assert acts is not None + vs = self.model.V(zs) + qs = self.model_target.Q(zs, acts, return_type="min") + adv = qs - vs + exp_a = torch.exp(adv * self.cfg.A_scaling) + exp_a = torch.clamp(exp_a, max=100.0) + log_probs = h.gaussian_logprob(self.model.pi(zs) - acts, 0) + rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) + pi_loss = -((exp_a * log_probs).mean(dim=(1, 2)) * rho).mean() + info["adv"] = adv[0] + + pi_loss.backward() + torch.nn.utils.clip_grad_norm_( + self.model._pi.parameters(), + self.cfg.grad_clip_norm, + error_if_nonfinite=False, + ) + self.pi_optim.step() + self.model.track_q_grad(True) + self.model.track_v_grad(True) + + info["pi_loss"] = pi_loss.item() + return pi_loss.item(), info + + @torch.no_grad() + def _td_target(self, next_z, reward, mask): + """Compute the TD-target from a reward and the observation at the following time step.""" + next_v = self.model.V(next_z) + td_target = reward + self.cfg.discount * mask * next_v + return td_target + + def update(self, replay_buffer, step, demo_buffer=None): + """Main update function. Corresponds to one iteration of the model learning.""" + + if demo_buffer is not None: + # Update oversampling ratio + self.demo_batch_size = int( + h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size + ) + replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size + demo_buffer.cfg.batch_size = self.demo_batch_size + else: + self.demo_batch_size = 0 + + # Sample from interaction dataset + obs, next_obses, action, reward, mask, done, idxs, weights = ( + replay_buffer.sample() + ) + + # Sample from demonstration dataset + if self.demo_batch_size > 0: + ( + demo_obs, + demo_next_obses, + demo_action, + demo_reward, + demo_mask, + demo_done, + demo_idxs, + demo_weights, + ) = demo_buffer.sample() + + if isinstance(obs, dict): + obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs} + next_obses = { + k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) + for k in next_obses + } + else: + obs = torch.cat([obs, demo_obs]) + next_obses = torch.cat([next_obses, demo_next_obses], dim=1) + action = torch.cat([action, demo_action], dim=1) + reward = torch.cat([reward, demo_reward], dim=1) + mask = torch.cat([mask, demo_mask], dim=1) + done = torch.cat([done, demo_done], dim=1) + idxs = torch.cat([idxs, demo_idxs]) + weights = torch.cat([weights, demo_weights]) + + horizon = self.cfg.horizon + loss_mask = torch.ones_like(mask, device=self.device) + for t in range(1, horizon): + loss_mask[t] = loss_mask[t - 1] * (~done[t - 1]) + + self.optim.zero_grad(set_to_none=True) + self.std = h.linear_schedule(self.cfg.std_schedule, step) + self.model.train() + + # Compute targets + with torch.no_grad(): + next_z = self.model.encode(next_obses) + z_targets = self.model_target.encode(next_obses) + td_targets = self._td_target(next_z, reward, mask) + 
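+        # Roll the latent dynamics model forward over the horizon, predicting the next latent state
+        # and reward for each action in the sampled sequence.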
+ # Latent rollout + zs = torch.empty( + horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device + ) + reward_preds = torch.empty_like(reward, device=self.device) + assert reward.shape[0] == horizon + z = self.model.encode(obs) + zs[0] = z + value_info = {"Q": 0.0, "V": 0.0} + for t in range(horizon): + z, reward_pred = self.model.next(z, action[t]) + zs[t + 1] = z + reward_preds[t] = reward_pred + + with torch.no_grad(): + v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min") + + # Predictions + qs = self.model.Q(zs[:-1], action, return_type="all") + value_info["Q"] = qs.mean().item() + v = self.model.V(zs[:-1]) + value_info["V"] = v.mean().item() + + # Losses + rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view( + -1, 1, 1 + ) + consistency_loss = ( + rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask + ).sum(dim=0) + reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0) + q_value_loss, priority_loss = 0, 0 + for q in range(self.cfg.num_q): + q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0) + priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0) + + self.expectile = h.linear_schedule(self.cfg.expectile, step) + v_value_loss = ( + rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask + ).sum(dim=0) + + total_loss = ( + self.cfg.consistency_coef * consistency_loss + + self.cfg.reward_coef * reward_loss + + self.cfg.value_coef * q_value_loss + + self.cfg.value_coef * v_value_loss + ) + + weighted_loss = (total_loss.squeeze(1) * weights).mean() + weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon)) + weighted_loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False + ) + self.optim.step() + + if self.cfg.per: + # Update priorities + priorities = priority_loss.clamp(max=1e4).detach() + replay_buffer.update_priorities( + idxs[: replay_buffer.cfg.batch_size], + priorities[: replay_buffer.cfg.batch_size], + ) + if self.demo_batch_size > 0: + demo_buffer.update_priorities( + demo_idxs, priorities[replay_buffer.cfg.batch_size :] + ) + + # Update policy + target network + _, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action) + + if step % self.cfg.update_freq == 0: + h.ema(self.model._encoder, self.model_target._encoder, self.cfg.tau) + h.ema(self.model._Qs, self.model_target._Qs, self.cfg.tau) + + self.model.eval() + metrics = { + "consistency_loss": float(consistency_loss.mean().item()), + "reward_loss": float(reward_loss.mean().item()), + "Q_value_loss": float(q_value_loss.mean().item()), + "V_value_loss": float(v_value_loss.mean().item()), + "total_loss": float(total_loss.mean().item()), + "weighted_loss": float(weighted_loss.mean().item()), + "grad_norm": float(grad_norm), + } + for key in ["demo_batch_size", "expectile"]: + if hasattr(self, key): + metrics[key] = getattr(self, key) + metrics.update(value_info) + metrics.update(pi_update_info) + + return metrics diff --git a/lerobot/common/tdmpc_helper.py b/lerobot/common/tdmpc_helper.py new file mode 100644 index 00000000..11e5c098 --- /dev/null +++ b/lerobot/common/tdmpc_helper.py @@ -0,0 +1,829 @@ +import os +import pickle +import re + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import distributions as pyd +from torch.distributions.utils import _standard_normal + +__REDUCE__ = lambda b: "mean" if b else 
"none" + + +def l1(pred, target, reduce=False): + """Computes the L1-loss between predictions and targets.""" + return F.l1_loss(pred, target, reduction=__REDUCE__(reduce)) + + +def mse(pred, target, reduce=False): + """Computes the MSE loss between predictions and targets.""" + return F.mse_loss(pred, target, reduction=__REDUCE__(reduce)) + + +def l2_expectile(diff, expectile=0.7, reduce=False): + weight = torch.where(diff > 0, expectile, (1 - expectile)) + loss = weight * (diff**2) + reduction = __REDUCE__(reduce) + if reduction == "mean": + return torch.mean(loss) + elif reduction == "sum": + return torch.sum(loss) + return loss + + +def _get_out_shape(in_shape, layers): + """Utility function. Returns the output shape of a network for a given input shape.""" + x = torch.randn(*in_shape).unsqueeze(0) + return ( + (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x) + .squeeze(0) + .shape + ) + + +def gaussian_logprob(eps, log_std): + """Compute Gaussian log probability.""" + residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True) + return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1) + + +def squash(mu, pi, log_pi): + """Apply squashing function.""" + mu = torch.tanh(mu) + pi = torch.tanh(pi) + log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) + return mu, pi, log_pi + + +def orthogonal_init(m): + """Orthogonal layer initialization.""" + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data, gain) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +def ema(m, m_target, tau): + """Update slow-moving average of online network (target network) at rate tau.""" + with torch.no_grad(): + for p, p_target in zip(m.parameters(), m_target.parameters()): + p_target.data.lerp_(p.data, tau) + + +def set_requires_grad(net, value): + """Enable/disable gradients for a given (sub)network.""" + for param in net.parameters(): + param.requires_grad_(value) + + +class TruncatedNormal(pyd.Normal): + """Utility class implementing the truncated normal distribution.""" + + def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): + super().__init__(loc, scale, validate_args=False) + self.low = low + self.high = high + self.eps = eps + + def _clamp(self, x): + clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) + x = x - x.detach() + clamped_x.detach() + return x + + def sample(self, clip=None, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + eps *= self.scale + if clip is not None: + eps = torch.clamp(eps, -clip, clip) + x = self.loc + eps + return self._clamp(x) + + +class NormalizeImg(nn.Module): + """Normalizes pixel observations to [0,1) range.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.div(255.0) + + +class Flatten(nn.Module): + """Flattens its input to a (batched) vector.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.view(x.size(0), -1) + + +def enc(cfg): + obs_shape = { + "rgb": (3, cfg.img_size, cfg.img_size), + "state": (4,), + } + + """Returns a TOLD encoder.""" + pixels_enc_layers, state_enc_layers = None, None + if cfg.modality in {"pixels", "all"}: + C = int(3 * cfg.frame_stack) + pixels_enc_layers = [ + NormalizeImg(), + nn.Conv2d(C, cfg.num_channels, 7, stride=2), + nn.ReLU(), 
+ nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2), + nn.ReLU(), + nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), + nn.ReLU(), + nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), + nn.ReLU(), + ] + out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers) + pixels_enc_layers.extend( + [ + Flatten(), + nn.Linear(np.prod(out_shape), cfg.latent_dim), + nn.LayerNorm(cfg.latent_dim), + nn.Sigmoid(), + ] + ) + if cfg.modality == "pixels": + return ConvExt(nn.Sequential(*pixels_enc_layers)) + if cfg.modality in {"state", "all"}: + state_dim = obs_shape[0] if cfg.modality == "state" else obs_shape["state"][0] + state_enc_layers = [ + nn.Linear(state_dim, cfg.enc_dim), + nn.ELU(), + nn.Linear(cfg.enc_dim, cfg.latent_dim), + nn.LayerNorm(cfg.latent_dim), + nn.Sigmoid(), + ] + if cfg.modality == "state": + return nn.Sequential(*state_enc_layers) + else: + raise NotImplementedError + + encoders = {} + for k in obs_shape: + if k == "state": + encoders[k] = nn.Sequential(*state_enc_layers) + elif k.endswith("rgb"): + encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers)) + else: + raise NotImplementedError + return Multiplexer(nn.ModuleDict(encoders)) + + +def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): + """Returns an MLP.""" + if isinstance(mlp_dim, int): + mlp_dim = [mlp_dim, mlp_dim] + return nn.Sequential( + nn.Linear(in_dim, mlp_dim[0]), + nn.LayerNorm(mlp_dim[0]), + act_fn, + nn.Linear(mlp_dim[0], mlp_dim[1]), + nn.LayerNorm(mlp_dim[1]), + act_fn, + nn.Linear(mlp_dim[1], out_dim), + ) + + +def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): + """Returns a dynamics network.""" + return nn.Sequential( + mlp(in_dim, mlp_dim, out_dim, act_fn), + nn.LayerNorm(out_dim), + nn.Sigmoid(), + ) + + +def q(cfg): + action_dim = 4 + """Returns a Q-function that uses Layer Normalization.""" + return nn.Sequential( + nn.Linear(cfg.latent_dim + action_dim, cfg.mlp_dim), + nn.LayerNorm(cfg.mlp_dim), + nn.Tanh(), + nn.Linear(cfg.mlp_dim, cfg.mlp_dim), + nn.ELU(), + nn.Linear(cfg.mlp_dim, 1), + ) + + +def v(cfg): + """Returns a state value function that uses Layer Normalization.""" + return nn.Sequential( + nn.Linear(cfg.latent_dim, cfg.mlp_dim), + nn.LayerNorm(cfg.mlp_dim), + nn.Tanh(), + nn.Linear(cfg.mlp_dim, cfg.mlp_dim), + nn.ELU(), + nn.Linear(cfg.mlp_dim, 1), + ) + + +def aug(cfg): + obs_shape = { + "rgb": (3, cfg.img_size, cfg.img_size), + "state": (4,), + } + + """Multiplex augmentation""" + if cfg.modality == "state": + return nn.Identity() + elif cfg.modality == "pixels": + return RandomShiftsAug(cfg) + else: + augs = {} + for k in obs_shape: + if k == "state": + augs[k] = nn.Identity() + elif k.endswith("rgb"): + augs[k] = RandomShiftsAug(cfg) + else: + raise NotImplementedError + return Multiplexer(nn.ModuleDict(augs)) + + +class ConvExt(nn.Module): + """Auxiliary conv net accommodating high-dim input""" + + def __init__(self, conv): + super().__init__() + self.conv = conv + + def forward(self, x): + if x.ndim > 4: + batch_shape = x.shape[:-3] + out = self.conv(x.view(-1, *x.shape[-3:])) + out = out.view(*batch_shape, *out.shape[1:]) + else: + out = self.conv(x) + return out + + +class Multiplexer(nn.Module): + """Model multiplexer""" + + def __init__(self, choices): + super().__init__() + self.choices = choices + + def forward(self, x, key=None): + if isinstance(x, dict): + if key is not None: + return self.choices[key](x) + return {k: self.choices[k](_x) for k, _x in x.items()} + return self.choices(x) + + +class 
RandomShiftsAug(nn.Module): + """ + Random shift image augmentation. + Adapted from https://github.com/facebookresearch/drqv2 + """ + + def __init__(self, cfg): + super().__init__() + assert cfg.modality in {"pixels", "all"} + self.pad = int(cfg.img_size / 21) + + def forward(self, x): + n, c, h, w = x.size() + assert h == w + padding = tuple([self.pad] * 4) + x = F.pad(x, padding, "replicate") + eps = 1.0 / (h + 2 * self.pad) + arange = torch.linspace( + -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype + )[:h] + arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) + base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) + base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) + shift = torch.randint( + 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype + ) + shift *= 2.0 / (h + 2 * self.pad) + grid = base_grid + shift + return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) + + +class Episode(object): + """Storage object for a single episode.""" + + def __init__(self, cfg, init_obs): + action_dim = 4 + + self.cfg = cfg + self.device = torch.device(cfg.buffer_device) + if cfg.modality in {"pixels", "state"}: + dtype = torch.float32 if cfg.modality == "state" else torch.uint8 + self.obses = torch.empty( + (cfg.episode_length + 1, *init_obs.shape), + dtype=dtype, + device=self.device, + ) + self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device) + elif cfg.modality == "all": + self.obses = {} + for k, v in init_obs.items(): + assert k in {"rgb", "state"} + dtype = torch.float32 if k == "state" else torch.uint8 + self.obses[k] = torch.empty( + (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device + ) + self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device) + else: + raise ValueError + self.actions = torch.empty( + (cfg.episode_length, action_dim), dtype=torch.float32, device=self.device + ) + self.rewards = torch.empty( + (cfg.episode_length,), dtype=torch.float32, device=self.device + ) + self.dones = torch.empty( + (cfg.episode_length,), dtype=torch.bool, device=self.device + ) + self.masks = torch.empty( + (cfg.episode_length,), dtype=torch.float32, device=self.device + ) + self.cumulative_reward = 0 + self.done = False + self.success = False + self._idx = 0 + + def __len__(self): + return self._idx + + @classmethod + def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None): + """Constructs an episode from a trajectory.""" + + if cfg.modality in {"pixels", "state"}: + episode = cls(cfg, obses[0]) + episode.obses[1:] = torch.tensor( + obses[1:], dtype=episode.obses.dtype, device=episode.device + ) + elif cfg.modality == "all": + episode = cls(cfg, {k: v[0] for k, v in obses.items()}) + for k, v in obses.items(): + episode.obses[k][1:] = torch.tensor( + obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device + ) + else: + raise NotImplementedError + episode.actions = torch.tensor( + actions, dtype=episode.actions.dtype, device=episode.device + ) + episode.rewards = torch.tensor( + rewards, dtype=episode.rewards.dtype, device=episode.device + ) + episode.dones = ( + torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device) + if dones is not None + else torch.zeros_like(episode.dones) + ) + episode.masks = ( + torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device) + if masks is not None + else torch.ones_like(episode.masks) + ) + episode.cumulative_reward = torch.sum(episode.rewards) + episode.done = True + episode._idx = 
cfg.episode_length + return episode + + @property + def first(self): + return len(self) == 0 + + def __add__(self, transition): + self.add(*transition) + return self + + def add(self, obs, action, reward, done, mask=1.0, success=False): + """Add a transition into the episode.""" + if isinstance(obs, dict): + for k, v in obs.items(): + self.obses[k][self._idx + 1] = torch.tensor( + v, dtype=self.obses[k].dtype, device=self.obses[k].device + ) + else: + self.obses[self._idx + 1] = torch.tensor( + obs, dtype=self.obses.dtype, device=self.obses.device + ) + self.actions[self._idx] = action + self.rewards[self._idx] = reward + self.dones[self._idx] = done + self.masks[self._idx] = mask + self.cumulative_reward += reward + self.done = done + self.success = self.success or success + self._idx += 1 + + +class ReplayBuffer: + """ + Storage and sampling functionality. + """ + + def __init__(self, cfg, dataset=None): + action_dim = 4 + obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (4,)} + + self.cfg = cfg + self.device = torch.device(cfg.buffer_device) + print("Replay buffer device: ", self.device) + + if dataset is not None: + self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size) + else: + self.capacity = min(cfg.train_steps, cfg.max_buffer_size) + + if cfg.modality in {"pixels", "state"}: + dtype = torch.float32 if cfg.modality == "state" else torch.uint8 + # Note self.obs_shape always has single frame, which is different from cfg.obs_shape + self.obs_shape = ( + obs_shape if cfg.modality == "state" else (3, *obs_shape[-2:]) + ) + self._obs = torch.zeros( + (self.capacity + cfg.horizon - 1, *self.obs_shape), + dtype=dtype, + device=self.device, + ) + self._next_obs = torch.zeros( + (self.capacity + cfg.horizon - 1, *self.obs_shape), + dtype=dtype, + device=self.device, + ) + elif cfg.modality == "all": + self.obs_shape = {} + self._obs, self._next_obs = {}, {} + for k, v in obs_shape.items(): + assert k in {"rgb", "state"} + dtype = torch.float32 if k == "state" else torch.uint8 + self.obs_shape[k] = v if k == "state" else (3, *v[-2:]) + self._obs[k] = torch.zeros( + (self.capacity + cfg.horizon - 1, *self.obs_shape[k]), + dtype=dtype, + device=self.device, + ) + self._next_obs[k] = self._obs[k].clone() + else: + raise ValueError + + self._action = torch.zeros( + (self.capacity + cfg.horizon - 1, action_dim), + dtype=torch.float32, + device=self.device, + ) + self._reward = torch.zeros( + (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device + ) + self._mask = torch.zeros( + (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device + ) + self._done = torch.zeros( + (self.capacity + cfg.horizon - 1,), dtype=torch.bool, device=self.device + ) + self._priorities = torch.ones( + (self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device + ) + self._eps = 1e-6 + self._full = False + self.idx = 0 + if dataset is not None: + self.init_from_offline_dataset(dataset) + + self._aug = aug(cfg) + + def init_from_offline_dataset(self, dataset): + """Initialize the replay buffer from an offline dataset.""" + assert self.idx == 0 and not self._full + n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent) + + def copy_data(dst, src, n): + assert isinstance(dst, dict) == isinstance(src, dict) + if isinstance(dst, dict): + for k in dst: + copy_data(dst[k], src[k], n) + else: + dst[:n] = torch.from_numpy(src[:n]) + + copy_data(self._obs, dataset["observations"], n_transitions) + copy_data(self._next_obs, 
dataset["next_observations"], n_transitions) + copy_data(self._action, dataset["actions"], n_transitions) + copy_data(self._reward, dataset["rewards"], n_transitions) + copy_data(self._mask, dataset["masks"], n_transitions) + copy_data(self._done, dataset["dones"], n_transitions) + self.idx = (self.idx + n_transitions) % self.capacity + self._full = n_transitions >= self.capacity + + def __add__(self, episode: Episode): + self.add(episode) + return self + + def add(self, episode: Episode): + """Add an episode to the replay buffer.""" + if self.idx + len(episode) > self.capacity: + print("Warning: episode got truncated") + ep_len = min(len(episode), self.capacity - self.idx) + idxs = slice(self.idx, self.idx + ep_len) + assert self.idx + ep_len <= self.capacity + if self.cfg.modality in {"pixels", "state"}: + self._obs[idxs] = ( + episode.obses[:ep_len] + if self.cfg.modality == "state" + else episode.obses[:ep_len, -3:] + ) + self._next_obs[idxs] = ( + episode.obses[1 : ep_len + 1] + if self.cfg.modality == "state" + else episode.obses[1 : ep_len + 1, -3:] + ) + elif self.cfg.modality == "all": + for k, v in episode.obses.items(): + assert k in {"rgb", "state"} + assert k in self._obs + assert k in self._next_obs + if k == "rgb": + self._obs[k][idxs] = episode.obses[k][:ep_len, -3:] + self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:] + else: + self._obs[k][idxs] = episode.obses[k][:ep_len] + self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1] + self._action[idxs] = episode.actions[:ep_len] + self._reward[idxs] = episode.rewards[:ep_len] + self._mask[idxs] = episode.masks[:ep_len] + self._done[idxs] = episode.dones[:ep_len] + self._done[self.idx + ep_len - 1] = True # in case truncated + if self._full: + max_priority = ( + self._priorities[: self.capacity].max().to(self.device).item() + ) + else: + max_priority = ( + 1.0 + if self.idx == 0 + else self._priorities[: self.idx].max().to(self.device).item() + ) + new_priorities = torch.full((ep_len,), max_priority, device=self.device) + self._priorities[idxs] = new_priorities + self.idx = (self.idx + ep_len) % self.capacity + self._full = self._full or self.idx == 0 + + def update_priorities(self, idxs, priorities): + """Update priorities for Prioritized Experience Replay (PER)""" + self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps + + def _get_obs(self, arr, idxs): + """Retrieve observations by indices""" + if isinstance(arr, dict): + return {k: self._get_obs(v, idxs) for k, v in arr.items()} + if arr.ndim <= 2: # if self.cfg.modality == 'state': + return arr[idxs].cuda() + obs = torch.empty( + (self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]), + dtype=arr.dtype, + device=torch.device("cuda"), + ) + obs[:, -3:] = arr[idxs].cuda() + _idxs = idxs.clone() + mask = torch.ones_like(_idxs, dtype=torch.bool) + for i in range(1, self.cfg.frame_stack): + mask[_idxs % self.cfg.episode_length == 0] = False + _idxs[mask] -= 1 + obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda() + return obs.float() + + def sample(self): + """Sample transitions from the replay buffer.""" + probs = ( + self._priorities[: self.capacity] + if self._full + else self._priorities[: self.idx] + ) ** self.cfg.per_alpha + probs /= probs.sum() + total = len(probs) + idxs = torch.from_numpy( + np.random.choice( + total, + self.cfg.batch_size, + p=probs.cpu().numpy(), + replace=not self._full, + ) + ).to(self.device) + weights = (total * probs[idxs]) ** (-self.cfg.per_beta) + weights /= weights.max() + + idxs_in_horizon = 
torch.stack([idxs + t for t in range(self.cfg.horizon)]) + + obs = self._aug(self._get_obs(self._obs, idxs)) + next_obs = [ + self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon + ] + if isinstance(next_obs[0], dict): + next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]} + else: + next_obs = torch.stack(next_obs) + action = self._action[idxs_in_horizon] + reward = self._reward[idxs_in_horizon] + mask = self._mask[idxs_in_horizon] + done = self._done[idxs_in_horizon] + + if not action.is_cuda: + action, reward, mask, done, idxs, weights = ( + action.cuda(), + reward.cuda(), + mask.cuda(), + done.cuda(), + idxs.cuda(), + weights.cuda(), + ) + + return ( + obs, + next_obs, + action, + reward.unsqueeze(2), + mask.unsqueeze(2), + done.unsqueeze(2), + idxs, + weights, + ) + + def save(self, path): + """Save the replay buffer to path""" + print(f"saving replay buffer to '{path}'...") + sz = self.capacity if self._full else self.idx + dataset = { + "observations": ( + {k: v[:sz].cpu().numpy() for k, v in self._obs.items()} + if isinstance(self._obs, dict) + else self._obs[:sz].cpu().numpy() + ), + "next_observations": ( + {k: v[:sz].cpu().numpy() for k, v in self._next_obs.items()} + if isinstance(self._next_obs, dict) + else self._next_obs[:sz].cpu().numpy() + ), + "actions": self._action[:sz].cpu().numpy(), + "rewards": self._reward[:sz].cpu().numpy(), + "dones": self._done[:sz].cpu().numpy(), + "masks": self._mask[:sz].cpu().numpy(), + } + with open(path, "wb") as f: + pickle.dump(dataset, f) + return dataset + + +def get_dataset_dict(cfg, env, return_reward_normalizer=False): + """Construct a dataset for env""" + required_keys = [ + "observations", + "next_observations", + "actions", + "rewards", + "dones", + "masks", + ] + + if cfg.task.startswith("xarm"): + dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl") + print(f"Using offline dataset '{dataset_path}'") + with open(dataset_path, "rb") as f: + dataset_dict = pickle.load(f) + for k in required_keys: + if k not in dataset_dict and k[:-1] in dataset_dict: + dataset_dict[k] = dataset_dict.pop(k[:-1]) + elif cfg.task.startswith("legged"): + dataset_path = os.path.join(cfg.dataset_dir, f"buffer.pkl") + print(f"Using offline dataset '{dataset_path}'") + with open(dataset_path, "rb") as f: + dataset_dict = pickle.load(f) + dataset_dict["actions"] /= env.unwrapped.clip_actions + print(f"clip_actions={env.unwrapped.clip_actions}") + else: + import d4rl + + dataset_dict = d4rl.qlearning_dataset(env) + dones = np.full_like(dataset_dict["rewards"], False, dtype=bool) + + for i in range(len(dones) - 1): + if ( + np.linalg.norm( + dataset_dict["observations"][i + 1] + - dataset_dict["next_observations"][i] + ) + > 1e-6 + or dataset_dict["terminals"][i] == 1.0 + ): + dones[i] = True + + dones[-1] = True + + dataset_dict["masks"] = 1.0 - dataset_dict["terminals"] + del dataset_dict["terminals"] + + for k, v in dataset_dict.items(): + dataset_dict[k] = v.astype(np.float32) + + dataset_dict["dones"] = dones + + if cfg.is_data_clip: + lim = 1 - cfg.data_clip_eps + dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim) + reward_normalizer = get_reward_normalizer(cfg, dataset_dict) + dataset_dict["rewards"] = reward_normalizer(dataset_dict["rewards"]) + + for key in required_keys: + assert key in dataset_dict.keys(), f"Missing `{key}` in dataset." 
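+    # Optionally return the reward normalizer alongside the dataset dict.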
+ + if return_reward_normalizer: + return dataset_dict, reward_normalizer + return dataset_dict + + +def get_trajectory_boundaries_and_returns(dataset): + """ + Split dataset into trajectories and compute returns + """ + episode_starts = [0] + episode_ends = [] + + episode_return = 0 + episode_returns = [] + + n_transitions = len(dataset["rewards"]) + + for i in range(n_transitions): + episode_return += dataset["rewards"][i] + + if dataset["dones"][i]: + episode_returns.append(episode_return) + episode_ends.append(i + 1) + if i + 1 < n_transitions: + episode_starts.append(i + 1) + episode_return = 0.0 + + return episode_starts, episode_ends, episode_returns + + +def normalize_returns(dataset, scaling=1000): + """ + Normalize returns in the dataset + """ + (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset) + dataset["rewards"] /= np.max(episode_returns) - np.min(episode_returns) + dataset["rewards"] *= scaling + return dataset + + +def get_reward_normalizer(cfg, dataset): + """ + Get a reward normalizer for the dataset + """ + if cfg.task.startswith("xarm"): + return lambda x: x + elif "maze" in cfg.task: + return lambda x: x - 1.0 + elif cfg.task.split("-")[0] in ["hopper", "halfcheetah", "walker2d"]: + (_, _, episode_returns) = get_trajectory_boundaries_and_returns(dataset) + return ( + lambda x: x / (np.max(episode_returns) - np.min(episode_returns)) * 1000.0 + ) + elif hasattr(cfg, "reward_scale"): + return lambda x: x * cfg.reward_scale + return lambda x: x + + +def linear_schedule(schdl, step): + """ + Outputs values following a linear decay schedule. + Adapted from https://github.com/facebookresearch/drqv2 + """ + try: + return float(schdl) + except ValueError: + match = re.match(r"linear\((.+),(.+),(.+),(.+)\)", schdl) + if match: + init, final, start, end = [float(g) for g in match.groups()] + mix = np.clip((step - start) / (end - start), 0.0, 1.0) + return (1.0 - mix) * init + mix * final + match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) + if match: + init, final, duration = [float(g) for g in match.groups()] + mix = np.clip(step / duration, 0.0, 1.0) + return (1.0 - mix) * init + mix * final + raise NotImplementedError(schdl) diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py new file mode 100644 index 00000000..a95adbc1 --- /dev/null +++ b/lerobot/common/utils.py @@ -0,0 +1,12 @@ +import random + +import numpy as np +import torch + + +def set_seed(seed): + """Set seed for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 0da0c60e..58558928 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -5,70 +5,54 @@ import imageio import numpy as np import torch from tensordict import TensorDict +from tensordict.nn import TensorDictModule from termcolor import colored -from lerobot.lib.envs.factory import make_env -from lerobot.lib.tdmpc import TDMPC -from lerobot.lib.utils import set_seed +from lerobot.common.envs.factory import make_env +from lerobot.common.tdmpc import TDMPC +from lerobot.common.utils import set_seed -def eval_agent( - env, agent, num_episodes: int, save_video: bool = False, video_path: Path = None +def eval_policy( + env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None ): - """Evaluate a trained agent and optionally save a video.""" - if save_video: - assert video_path is not None - assert video_path.suffix == ".mp4" - episode_rewards = [] - 
episode_successes = [] - episode_lengths = [] + rewards = [] + successes = [] for i in range(num_episodes): - td = env.reset() - obs = {} - obs["rgb"] = td["observation"]["camera"] - obs["state"] = td["observation"]["robot_state"] + ep_frames = [] - done = False - ep_reward = 0 - t = 0 - ep_success = False + def rendering_callback(env, td=None): + nonlocal ep_frames + frame = env.render() + ep_frames.append(frame) + + tensordict = env.reset() + # render first frame before rollout + rendering_callback(env) + + rollout = env.rollout( + max_steps=30, + policy=policy, + callback=rendering_callback, + auto_reset=False, + tensordict=tensordict, + ) + ep_reward = rollout["next", "reward"].sum() + ep_success = rollout["next", "success"].any() + rewards.append(ep_reward.item()) + successes.append(ep_success.item()) if save_video: - frames = [] - while not done: - action = agent.act(obs, t0=t == 0, eval_mode=True, step=100000) - td = TensorDict({"action": action}, batch_size=[]) - - td = env.step(td) - - reward = td["next", "reward"].item() - success = td["next", "success"].item() - done = td["next", "done"].item() - - obs = {} - obs["rgb"] = td["next", "observation"]["camera"] - obs["state"] = td["next", "observation"]["robot_state"] - - ep_reward += reward - if success: - ep_success = True - if save_video: - frame = env.render() - frames.append(frame) - t += 1 - episode_rewards.append(float(ep_reward)) - episode_successes.append(float(ep_success)) - episode_lengths.append(t) - if save_video: - video_path.parent.mkdir(parents=True, exist_ok=True) - frames = np.stack(frames) # .transpose(0, 3, 1, 2) + video_dir.parent.mkdir(parents=True, exist_ok=True) # TODO(rcadene): make fps configurable - imageio.mimsave(video_path, frames, fps=15) - return { - "episode_reward": np.nanmean(episode_rewards), - "episode_success": np.nanmean(episode_successes), - "episode_length": np.nanmean(episode_lengths), + video_path = video_dir / f"eval_episode_{i}.mp4" + imageio.mimsave(video_path, np.stack(ep_frames), fps=15) + + metrics = { + "avg_reward": np.nanmean(rewards), + "pc_success": np.nanmean(successes) * 100, } + return metrics @hydra.main(version_base=None, config_name="default", config_path="../configs") @@ -78,20 +62,25 @@ def eval(cfg: dict): print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir) env = make_env(cfg) - agent = TDMPC(cfg) + policy = TDMPC(cfg) # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" - agent.load(ckpt_path) + policy.load(ckpt_path) - eval_metrics = eval_agent( - env, - agent, - num_episodes=10, - save_video=True, - video_path=Path("tmp/2023_01_29_xarm_lift_final/eval.mp4"), + policy = TensorDictModule( + policy, + in_keys=["observation", "step_count"], + out_keys=["action"], ) - print(eval_metrics) + metrics = eval_policy( + env, + policy, + num_episodes=10, + save_video=True, + video_dir=Path("tmp/2023_01_29_xarm_lift_final"), + ) + print(metrics) if __name__ == "__main__": diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 2a3c7970..b4a9edad 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -2,7 +2,10 @@ import hydra import torch from termcolor import colored -from ..lib.utils import set_seed +from lerobot.common.envs.factory import make_env +from lerobot.common.tdmpc import TDMPC + +from ..common.utils import set_seed @hydra.main(version_base=None, config_name="default", config_path="../configs") @@ -11,6 
+14,24 @@ def train(cfg: dict): set_seed(cfg.seed) print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir) + env = make_env(cfg) + agent = TDMPC(cfg) + # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt" + ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt" + agent.load(ckpt_path) + + # online training + + eval_metrics = train_agent( + env, + agent, + num_episodes=10, + save_video=True, + video_dir=Path("tmp/2023_01_29_xarm_lift_final"), + ) + + print(eval_metrics) + if __name__ == "__main__": train() diff --git a/test/test_envs.py b/test/test_envs.py index 0968971b..49f547b6 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -2,7 +2,7 @@ import pytest from tensordict import TensorDict from torchrl.envs.utils import check_env_specs, step_mdp -from lerobot.lib.envs import SimxarmEnv +from lerobot.common.envs import SimxarmEnv @pytest.mark.parametrize(