Merge pull request #5 from Cadene/user/aliberts/cpu_run

Add CPU-only run compatibility
Simon Alibert 2024-03-03 13:06:39 +01:00 committed by GitHub
commit d3aae3111c
5 changed files with 99 additions and 95 deletions


@@ -2,7 +2,7 @@ def make_policy(cfg):
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc import TDMPC
-        policy = TDMPC(cfg.policy)
+        policy = TDMPC(cfg.policy, cfg.device)
     elif cfg.policy.name == "diffusion":
         from lerobot.common.policies.diffusion.policy import DiffusionPolicy


@@ -88,14 +88,14 @@ class TOLD(nn.Module):
 class TDMPC(nn.Module):
     """Implementation of TD-MPC learning + inference."""

-    def __init__(self, cfg):
+    def __init__(self, cfg, device):
         super().__init__()
         self.action_dim = cfg.action_dim
         self.cfg = cfg
-        self.device = torch.device("cuda")
+        self.device = torch.device(device)
         self.std = h.linear_schedule(cfg.std_schedule, 0)
-        self.model = TOLD(cfg).cuda()
+        self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
         self.model_target = deepcopy(self.model)
         self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
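
One caveat in the new construction line: if device == "cuda" but no GPU is visible, the model silently stays on CPU while self.device still reads cuda, so later transfers to self.device would fail. A more defensive, device-agnostic sketch (not the PR's code) resolves the device once and moves the module with .to():

import torch

def resolve_device(requested: str) -> torch.device:
    # Fall back to CPU when CUDA is requested but unavailable, so that
    # self.device always matches where the weights actually live.
    if requested == "cuda" and not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(requested)

# Inside TDMPC.__init__ this would read:
#   self.device = resolve_device(device)
#   self.model = TOLD(cfg).to(self.device)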


@@ -341,102 +341,103 @@ class RandomShiftsAug(nn.Module):
         return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)

-class Episode:
-    """Storage object for a single episode."""
-
-    def __init__(self, cfg, init_obs):
-        action_dim = cfg.action_dim
-        self.cfg = cfg
-        self.device = torch.device(cfg.buffer_device)
-        if cfg.modality in {"pixels", "state"}:
-            dtype = torch.float32 if cfg.modality == "state" else torch.uint8
-            self.obses = torch.empty(
-                (cfg.episode_length + 1, *init_obs.shape),
-                dtype=dtype,
-                device=self.device,
-            )
-            self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
-        elif cfg.modality == "all":
-            self.obses = {}
-            for k, v in init_obs.items():
-                assert k in {"rgb", "state"}
-                dtype = torch.float32 if k == "state" else torch.uint8
-                self.obses[k] = torch.empty(
-                    (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
-                )
-                self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
-        else:
-            raise ValueError
-        self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
-        self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
-        self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
-        self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
-        self.cumulative_reward = 0
-        self.done = False
-        self.success = False
-        self._idx = 0
-
-    def __len__(self):
-        return self._idx
-
-    @classmethod
-    def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
-        """Constructs an episode from a trajectory."""
-        if cfg.modality in {"pixels", "state"}:
-            episode = cls(cfg, obses[0])
-            episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
-        elif cfg.modality == "all":
-            episode = cls(cfg, {k: v[0] for k, v in obses.items()})
-            for k in obses:
-                episode.obses[k][1:] = torch.tensor(
-                    obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
-                )
-        else:
-            raise NotImplementedError
-        episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
-        episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
-        episode.dones = (
-            torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
-            if dones is not None
-            else torch.zeros_like(episode.dones)
-        )
-        episode.masks = (
-            torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
-            if masks is not None
-            else torch.ones_like(episode.masks)
-        )
-        episode.cumulative_reward = torch.sum(episode.rewards)
-        episode.done = True
-        episode._idx = cfg.episode_length
-        return episode
-
-    @property
-    def first(self):
-        return len(self) == 0
-
-    def __add__(self, transition):
-        self.add(*transition)
-        return self
-
-    def add(self, obs, action, reward, done, mask=1.0, success=False):
-        """Add a transition into the episode."""
-        if isinstance(obs, dict):
-            for k, v in obs.items():
-                self.obses[k][self._idx + 1] = torch.tensor(
-                    v, dtype=self.obses[k].dtype, device=self.obses[k].device
-                )
-        else:
-            self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
-        self.actions[self._idx] = action
-        self.rewards[self._idx] = reward
-        self.dones[self._idx] = done
-        self.masks[self._idx] = mask
-        self.cumulative_reward += reward
-        self.done = done
-        self.success = self.success or success
-        self._idx += 1
+# TODO(aliberts): remove class
+# class Episode:
+#     """Storage object for a single episode."""
+#
+#     def __init__(self, cfg, init_obs):
+#         action_dim = cfg.action_dim
+#         self.cfg = cfg
+#         self.device = torch.device(cfg.buffer_device)
+#         if cfg.modality in {"pixels", "state"}:
+#             dtype = torch.float32 if cfg.modality == "state" else torch.uint8
+#             self.obses = torch.empty(
+#                 (cfg.episode_length + 1, *init_obs.shape),
+#                 dtype=dtype,
+#                 device=self.device,
+#             )
+#             self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
+#         elif cfg.modality == "all":
+#             self.obses = {}
+#             for k, v in init_obs.items():
+#                 assert k in {"rgb", "state"}
+#                 dtype = torch.float32 if k == "state" else torch.uint8
+#                 self.obses[k] = torch.empty(
+#                     (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
+#                 )
+#                 self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
+#         else:
+#             raise ValueError
+#         self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
+#         self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
+#         self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
+#         self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
+#         self.cumulative_reward = 0
+#         self.done = False
+#         self.success = False
+#         self._idx = 0
+#
+#     def __len__(self):
+#         return self._idx
+#
+#     @classmethod
+#     def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
+#         """Constructs an episode from a trajectory."""
+#         if cfg.modality in {"pixels", "state"}:
+#             episode = cls(cfg, obses[0])
+#             episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
+#         elif cfg.modality == "all":
+#             episode = cls(cfg, {k: v[0] for k, v in obses.items()})
+#             for k in obses:
+#                 episode.obses[k][1:] = torch.tensor(
+#                     obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
+#                 )
+#         else:
+#             raise NotImplementedError
+#         episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
+#         episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
+#         episode.dones = (
+#             torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
+#             if dones is not None
+#             else torch.zeros_like(episode.dones)
+#         )
+#         episode.masks = (
+#             torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
+#             if masks is not None
+#             else torch.ones_like(episode.masks)
+#         )
+#         episode.cumulative_reward = torch.sum(episode.rewards)
+#         episode.done = True
+#         episode._idx = cfg.episode_length
+#         return episode
+#
+#     @property
+#     def first(self):
+#         return len(self) == 0
+#
+#     def __add__(self, transition):
+#         self.add(*transition)
+#         return self
+#
+#     def add(self, obs, action, reward, done, mask=1.0, success=False):
+#         """Add a transition into the episode."""
+#         if isinstance(obs, dict):
+#             for k, v in obs.items():
+#                 self.obses[k][self._idx + 1] = torch.tensor(
+#                     v, dtype=self.obses[k].dtype, device=self.obses[k].device
+#                 )
+#         else:
+#             self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
+#         self.actions[self._idx] = action
+#         self.rewards[self._idx] = reward
+#         self.dones[self._idx] = done
+#         self.masks[self._idx] = mask
+#         self.cumulative_reward += reward
+#         self.done = done
+#         self.success = self.success or success
+#         self._idx += 1

 def get_dataset_dict(cfg, env, return_reward_normalizer=False):


@@ -10,8 +10,7 @@ hydra:
 name: default
 seed: 1337
-device: cuda
-buffer_device: cuda
+device: cuda  # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
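
With buffer_device removed, a single device key now controls placement, and switching to CPU is a plain Hydra override rather than a config edit. A sketch via Hydra's compose API (the config path and name below are assumptions, not verified against the repo layout):

from hydra import compose, initialize

# Equivalent to overriding on the command line, e.g. `... device=cpu`.
with initialize(config_path="../configs", version_base=None):  # assumed path
    cfg = compose(config_name="default", overrides=["device=cpu"])
assert cfg.device == "cpu"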


@@ -115,7 +115,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
     init_logging()

-    assert torch.cuda.is_available()
+    if cfg.device == "cuda":
+        assert torch.cuda.is_available()
+    else:
+        logging.warning("Using CPU, this will be slow.")
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
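
The guard now asserts CUDA availability only when CUDA is actually requested and degrades to a warning otherwise; the two torch.backends flags kept below it are harmless no-ops on CPU-only machines. If other entry points (e.g. evaluation) need the same behavior, the check could be factored into a small helper. A hypothetical sketch, not part of this PR:

import logging
import torch

def check_device(device: str) -> None:
    # Hypothetical helper mirroring the guard in train().
    if device == "cuda":
        # Fail fast if a GPU was requested but none is visible.
        assert torch.cuda.is_available()
    else:
        logging.warning("Using CPU, this will be slow.")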