From b33ec5a6302f57350e4604d5e244a96ddf20d4e7 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Sun, 3 Mar 2024 12:47:26 +0100
Subject: [PATCH] Add run on cpu-only compatibility

---
 lerobot/common/policies/factory.py      |   2 +-
 lerobot/common/policies/tdmpc.py        |   6 +-
 lerobot/common/policies/tdmpc_helper.py | 177 ++++++++------
 lerobot/configs/default.yaml            |   3 +-
 lerobot/scripts/train.py                |   7 +-
 5 files changed, 100 insertions(+), 95 deletions(-)

diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 3ce207f0..15a2c21d 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -2,7 +2,7 @@ def make_policy(cfg):
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc import TDMPC
 
-        policy = TDMPC(cfg.policy)
+        policy = TDMPC(cfg.policy, cfg.device)
     elif cfg.policy.name == "diffusion":
         from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 
diff --git a/lerobot/common/policies/tdmpc.py b/lerobot/common/policies/tdmpc.py
index 64908d62..42fbb825 100644
--- a/lerobot/common/policies/tdmpc.py
+++ b/lerobot/common/policies/tdmpc.py
@@ -88,14 +88,14 @@ class TOLD(nn.Module):
 class TDMPC(nn.Module):
     """Implementation of TD-MPC learning + inference."""
 
-    def __init__(self, cfg):
+    def __init__(self, cfg, device):
         super().__init__()
         self.action_dim = cfg.action_dim
 
         self.cfg = cfg
-        self.device = torch.device("cuda")
+        self.device = torch.device(device)
         self.std = h.linear_schedule(cfg.std_schedule, 0)
-        self.model = TOLD(cfg).cuda()
+        self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
         self.model_target = deepcopy(self.model)
         self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
diff --git a/lerobot/common/policies/tdmpc_helper.py b/lerobot/common/policies/tdmpc_helper.py
index 2c2ab4f2..964f1718 100644
--- a/lerobot/common/policies/tdmpc_helper.py
+++ b/lerobot/common/policies/tdmpc_helper.py
@@ -341,102 +341,103 @@ class RandomShiftsAug(nn.Module):
         return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
 
 
-class Episode:
-    """Storage object for a single episode."""
+# TODO(aliberts): remove class
+# class Episode:
+#     """Storage object for a single episode."""
 
-    def __init__(self, cfg, init_obs):
-        action_dim = cfg.action_dim
+#     def __init__(self, cfg, init_obs):
+#         action_dim = cfg.action_dim
 
-        self.cfg = cfg
-        self.device = torch.device(cfg.buffer_device)
-        if cfg.modality in {"pixels", "state"}:
-            dtype = torch.float32 if cfg.modality == "state" else torch.uint8
-            self.obses = torch.empty(
-                (cfg.episode_length + 1, *init_obs.shape),
-                dtype=dtype,
-                device=self.device,
-            )
-            self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
-        elif cfg.modality == "all":
-            self.obses = {}
-            for k, v in init_obs.items():
-                assert k in {"rgb", "state"}
-                dtype = torch.float32 if k == "state" else torch.uint8
-                self.obses[k] = torch.empty(
-                    (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
-                )
-                self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
-        else:
-            raise ValueError
-        self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
-        self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
-        self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
-        self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
-        self.cumulative_reward = 0
-        self.done = False
-        self.success = False
-        self._idx = 0
+#         self.cfg = cfg
+#         self.device = torch.device(cfg.buffer_device)
+#         if cfg.modality in {"pixels", "state"}:
+#             dtype = torch.float32 if cfg.modality == "state" else torch.uint8
+#             self.obses = torch.empty(
+#                 (cfg.episode_length + 1, *init_obs.shape),
+#                 dtype=dtype,
+#                 device=self.device,
+#             )
+#             self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
+#         elif cfg.modality == "all":
+#             self.obses = {}
+#             for k, v in init_obs.items():
+#                 assert k in {"rgb", "state"}
+#                 dtype = torch.float32 if k == "state" else torch.uint8
+#                 self.obses[k] = torch.empty(
+#                     (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
+#                 )
+#                 self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
+#         else:
+#             raise ValueError
+#         self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
+#         self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
+#         self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
+#         self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
+#         self.cumulative_reward = 0
+#         self.done = False
+#         self.success = False
+#         self._idx = 0
 
-    def __len__(self):
-        return self._idx
+#     def __len__(self):
+#         return self._idx
 
-    @classmethod
-    def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
-        """Constructs an episode from a trajectory."""
+#     @classmethod
+#     def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
+#         """Constructs an episode from a trajectory."""
 
-        if cfg.modality in {"pixels", "state"}:
-            episode = cls(cfg, obses[0])
-            episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
-        elif cfg.modality == "all":
-            episode = cls(cfg, {k: v[0] for k, v in obses.items()})
-            for k in obses:
-                episode.obses[k][1:] = torch.tensor(
-                    obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
-                )
-        else:
-            raise NotImplementedError
-        episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
-        episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
-        episode.dones = (
-            torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
-            if dones is not None
-            else torch.zeros_like(episode.dones)
-        )
-        episode.masks = (
-            torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
-            if masks is not None
-            else torch.ones_like(episode.masks)
-        )
-        episode.cumulative_reward = torch.sum(episode.rewards)
-        episode.done = True
-        episode._idx = cfg.episode_length
-        return episode
+#         if cfg.modality in {"pixels", "state"}:
+#             episode = cls(cfg, obses[0])
+#             episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
+#         elif cfg.modality == "all":
+#             episode = cls(cfg, {k: v[0] for k, v in obses.items()})
+#             for k in obses:
+#                 episode.obses[k][1:] = torch.tensor(
+#                     obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
+#                 )
+#         else:
+#             raise NotImplementedError
+#         episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
+#         episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
+#         episode.dones = (
+#             torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
+#             if dones is not None
+#             else torch.zeros_like(episode.dones)
+#         )
+#         episode.masks = (
+#             torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
+#             if masks is not None
+#             else torch.ones_like(episode.masks)
+#         )
+#         episode.cumulative_reward = torch.sum(episode.rewards)
+#         episode.done = True
+#         episode._idx = cfg.episode_length
+#         return episode
 
-    @property
-    def first(self):
-        return len(self) == 0
+#     @property
+#     def first(self):
+#         return len(self) == 0
 
-    def __add__(self, transition):
-        self.add(*transition)
-        return self
+#     def __add__(self, transition):
+#         self.add(*transition)
+#         return self
 
-    def add(self, obs, action, reward, done, mask=1.0, success=False):
-        """Add a transition into the episode."""
-        if isinstance(obs, dict):
-            for k, v in obs.items():
-                self.obses[k][self._idx + 1] = torch.tensor(
-                    v, dtype=self.obses[k].dtype, device=self.obses[k].device
-                )
-        else:
-            self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
-        self.actions[self._idx] = action
-        self.rewards[self._idx] = reward
-        self.dones[self._idx] = done
-        self.masks[self._idx] = mask
-        self.cumulative_reward += reward
-        self.done = done
-        self.success = self.success or success
-        self._idx += 1
+#     def add(self, obs, action, reward, done, mask=1.0, success=False):
+#         """Add a transition into the episode."""
+#         if isinstance(obs, dict):
+#             for k, v in obs.items():
+#                 self.obses[k][self._idx + 1] = torch.tensor(
+#                     v, dtype=self.obses[k].dtype, device=self.obses[k].device
+#                 )
+#         else:
+#             self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
+#         self.actions[self._idx] = action
+#         self.rewards[self._idx] = reward
+#         self.dones[self._idx] = done
+#         self.masks[self._idx] = mask
+#         self.cumulative_reward += reward
+#         self.done = done
+#         self.success = self.success or success
+#         self._idx += 1
 
 
 def get_dataset_dict(cfg, env, return_reward_normalizer=False):
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 97f560f5..0001a49e 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -10,8 +10,7 @@ hydra:
   name: default
 
 seed: 1337
-device: cuda
-buffer_device: cuda
+device: cuda # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index a537835e..c2eef313 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -1,4 +1,5 @@
 import logging
+import warnings
 
 import hydra
 import numpy as np
@@ -115,7 +116,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     init_logging()
 
-    assert torch.cuda.is_available()
+    if cfg.device == "cuda":
+        assert torch.cuda.is_available()
+    else:
+        warnings.warn("Using CPU, this will be slow.", UserWarning, stacklevel=1)
+
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
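
For illustration, here is a minimal, self-contained sketch of the CPU/CUDA fallback pattern this patch applies to TDMPC. `DemoPolicy` and its layer sizes are hypothetical stand-ins, not part of lerobot:

import torch
import torch.nn as nn


class DemoPolicy(nn.Module):
    """Toy module showing the device-selection pattern used in the patch."""

    def __init__(self, device: str = "cuda"):
        super().__init__()
        # Use CUDA only when it is both requested and actually available,
        # mirroring the `TOLD(cfg).cuda() if ... else TOLD(cfg)` ternary above.
        use_cuda = device == "cuda" and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.net = nn.Linear(4, 2).to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move inputs to the module's device so CPU-only runs work unchanged.
        return self.net(x.to(self.device))


policy = DemoPolicy("cuda")  # falls back to CPU on a GPU-less machine
print(policy(torch.zeros(1, 4)))

One subtlety in the patch itself: `self.device = torch.device(device)` keeps whatever device was requested, while the model only moves to CUDA when it is actually available, so the two can disagree on a GPU-less machine configured with `device: cuda`. Resolving the fallback once, as in the sketch, and then calling `TOLD(cfg).to(self.device)` would keep them consistent; `nn.Module.to` accepts a `torch.device` directly.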
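
With `buffer_device` removed and `device` exposed in default.yaml, a CPU-only run should be selectable from the command line through a standard Hydra override, assuming train.py keeps its Hydra entry point:

python lerobot/scripts/train.py device=cpu

The assertion in train.py then fires only when CUDA is explicitly requested but unavailable; any other device value just emits the slow-CPU warning and proceeds.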