Add compatibility for running on CPU-only machines
parent 661bda45ea
commit b33ec5a630
@@ -2,7 +2,7 @@ def make_policy(cfg):
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc import TDMPC

-        policy = TDMPC(cfg.policy)
+        policy = TDMPC(cfg.policy, cfg.device)
     elif cfg.policy.name == "diffusion":
         from lerobot.common.policies.diffusion.policy import DiffusionPolicy

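The factory change threads `cfg.device` through to the policy constructor instead of letting `TDMPC` assume CUDA. A minimal sketch of the call site, assuming a Hydra-style config object; the `SimpleNamespace` stand-in below is illustrative only, not part of the commit:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the Hydra config that make_policy() receives;
# the real cfg.policy carries many more fields (action_dim, lr, ...).
cfg = SimpleNamespace(
    device="cpu",  # or "cuda" on a GPU machine
    policy=SimpleNamespace(name="tdmpc"),
)

# make_policy(cfg) now builds the policy as TDMPC(cfg.policy, cfg.device),
# so a CPU-only host no longer trips over a hard-coded torch.device("cuda").
```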
@@ -88,14 +88,14 @@ class TOLD(nn.Module):
 class TDMPC(nn.Module):
     """Implementation of TD-MPC learning + inference."""

-    def __init__(self, cfg):
+    def __init__(self, cfg, device):
         super().__init__()
         self.action_dim = cfg.action_dim

         self.cfg = cfg
-        self.device = torch.device("cuda")
+        self.device = torch.device(device)
         self.std = h.linear_schedule(cfg.std_schedule, 0)
-        self.model = TOLD(cfg).cuda()
+        self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
         self.model_target = deepcopy(self.model)
         self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
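Note that `TOLD(cfg)` is only moved to CUDA when both the request and the hardware agree; otherwise the module is left on CPU. An equivalent, slightly more uniform idiom (a sketch, not what the commit uses) is to resolve the device once and call `.to()`:

```python
import torch
import torch.nn as nn


class TinyTOLD(nn.Module):
    # Minimal stand-in for TOLD, just to make the sketch self-contained.
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)


requested = "cuda"
# Resolve the device once, then move the module; .to() is a no-op when the
# module is already on the target device, so one line covers both cases.
device = torch.device("cuda" if requested == "cuda" and torch.cuda.is_available() else "cpu")
model = TinyTOLD().to(device)
```

Resolving the device up front also keeps `self.device` consistent with where the module's parameters actually live, which matters later whenever tensors are created with `device=self.device`.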
@@ -341,102 +341,103 @@ class RandomShiftsAug(nn.Module):
         return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)


-class Episode:
-    """Storage object for a single episode."""
+# TODO(aliberts): remove class
+# class Episode:
+#     """Storage object for a single episode."""

-    def __init__(self, cfg, init_obs):
-        action_dim = cfg.action_dim
+#     def __init__(self, cfg, init_obs):
+#         action_dim = cfg.action_dim

-        self.cfg = cfg
-        self.device = torch.device(cfg.buffer_device)
-        if cfg.modality in {"pixels", "state"}:
-            dtype = torch.float32 if cfg.modality == "state" else torch.uint8
-            self.obses = torch.empty(
-                (cfg.episode_length + 1, *init_obs.shape),
-                dtype=dtype,
-                device=self.device,
-            )
-            self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
-        elif cfg.modality == "all":
-            self.obses = {}
-            for k, v in init_obs.items():
-                assert k in {"rgb", "state"}
-                dtype = torch.float32 if k == "state" else torch.uint8
-                self.obses[k] = torch.empty(
-                    (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
-                )
-                self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
-        else:
-            raise ValueError
-        self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
-        self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
-        self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
-        self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
-        self.cumulative_reward = 0
-        self.done = False
-        self.success = False
-        self._idx = 0
+#         self.cfg = cfg
+#         self.device = torch.device(cfg.buffer_device)
+#         if cfg.modality in {"pixels", "state"}:
+#             dtype = torch.float32 if cfg.modality == "state" else torch.uint8
+#             self.obses = torch.empty(
+#                 (cfg.episode_length + 1, *init_obs.shape),
+#                 dtype=dtype,
+#                 device=self.device,
+#             )
+#             self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
+#         elif cfg.modality == "all":
+#             self.obses = {}
+#             for k, v in init_obs.items():
+#                 assert k in {"rgb", "state"}
+#                 dtype = torch.float32 if k == "state" else torch.uint8
+#                 self.obses[k] = torch.empty(
+#                     (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
+#                 )
+#                 self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
+#         else:
+#             raise ValueError
+#         self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
+#         self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
+#         self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
+#         self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
+#         self.cumulative_reward = 0
+#         self.done = False
+#         self.success = False
+#         self._idx = 0

-    def __len__(self):
-        return self._idx
+#     def __len__(self):
+#         return self._idx

-    @classmethod
-    def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
-        """Constructs an episode from a trajectory."""
+#     @classmethod
+#     def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
+#         """Constructs an episode from a trajectory."""

-        if cfg.modality in {"pixels", "state"}:
-            episode = cls(cfg, obses[0])
-            episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
-        elif cfg.modality == "all":
-            episode = cls(cfg, {k: v[0] for k, v in obses.items()})
-            for k in obses:
-                episode.obses[k][1:] = torch.tensor(
-                    obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
-                )
-        else:
-            raise NotImplementedError
-        episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
-        episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
-        episode.dones = (
-            torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
-            if dones is not None
-            else torch.zeros_like(episode.dones)
-        )
-        episode.masks = (
-            torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
-            if masks is not None
-            else torch.ones_like(episode.masks)
-        )
-        episode.cumulative_reward = torch.sum(episode.rewards)
-        episode.done = True
-        episode._idx = cfg.episode_length
-        return episode
+#         if cfg.modality in {"pixels", "state"}:
+#             episode = cls(cfg, obses[0])
+#             episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
+#         elif cfg.modality == "all":
+#             episode = cls(cfg, {k: v[0] for k, v in obses.items()})
+#             for k in obses:
+#                 episode.obses[k][1:] = torch.tensor(
+#                     obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
+#                 )
+#         else:
+#             raise NotImplementedError
+#         episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
+#         episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
+#         episode.dones = (
+#             torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
+#             if dones is not None
+#             else torch.zeros_like(episode.dones)
+#         )
+#         episode.masks = (
+#             torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
+#             if masks is not None
+#             else torch.ones_like(episode.masks)
+#         )
+#         episode.cumulative_reward = torch.sum(episode.rewards)
+#         episode.done = True
+#         episode._idx = cfg.episode_length
+#         return episode

-    @property
-    def first(self):
-        return len(self) == 0
+#     @property
+#     def first(self):
+#         return len(self) == 0

-    def __add__(self, transition):
-        self.add(*transition)
-        return self
+#     def __add__(self, transition):
+#         self.add(*transition)
+#         return self

-    def add(self, obs, action, reward, done, mask=1.0, success=False):
-        """Add a transition into the episode."""
-        if isinstance(obs, dict):
-            for k, v in obs.items():
-                self.obses[k][self._idx + 1] = torch.tensor(
-                    v, dtype=self.obses[k].dtype, device=self.obses[k].device
-                )
-        else:
-            self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
-        self.actions[self._idx] = action
-        self.rewards[self._idx] = reward
-        self.dones[self._idx] = done
-        self.masks[self._idx] = mask
-        self.cumulative_reward += reward
-        self.done = done
-        self.success = self.success or success
-        self._idx += 1
+#     def add(self, obs, action, reward, done, mask=1.0, success=False):
+#         """Add a transition into the episode."""
+#         if isinstance(obs, dict):
+#             for k, v in obs.items():
+#                 self.obses[k][self._idx + 1] = torch.tensor(
+#                     v, dtype=self.obses[k].dtype, device=self.obses[k].device
+#                 )
+#         else:
+#             self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
+#         self.actions[self._idx] = action
+#         self.rewards[self._idx] = reward
+#         self.dones[self._idx] = done
+#         self.masks[self._idx] = mask
+#         self.cumulative_reward += reward
+#         self.done = done
+#         self.success = self.success or success
+#         self._idx += 1


 def get_dataset_dict(cfg, env, return_reward_normalizer=False):
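The `Episode` buffer is dead-coded rather than deleted, per the `TODO(aliberts): remove class` marker, so the code stays greppable until a replacement lands. Its core pattern is still worth noting: preallocate fixed-length tensors with `torch.empty` on a chosen device, then fill them by index. A minimal sketch with made-up shapes (names here are illustrative):

```python
import torch

episode_length, action_dim = 25, 4
device = torch.device("cpu")  # was cfg.buffer_device before this commit

# Preallocate once, then write transitions in place: the pattern the
# commented-out Episode class used for obses, actions, rewards and masks.
actions = torch.empty((episode_length, action_dim), dtype=torch.float32, device=device)
rewards = torch.empty((episode_length,), dtype=torch.float32, device=device)

idx = 0
actions[idx] = torch.zeros(action_dim)  # first transition
rewards[idx] = 1.0
idx += 1
```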
@@ -10,8 +10,7 @@ hydra:
   name: default

 seed: 1337
-device: cuda
-buffer_device: cuda
+device: cuda # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
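Dropping `buffer_device` matches the removal of the `Episode` class above, and the `# cpu` comment documents the other accepted value for `device`. Since this is a Hydra config, the key can also be overridden at launch time without editing the yaml; a sketch using the compose API, where `config_path` and `config_name` are assumptions, not taken from this diff:

```python
from hydra import compose, initialize

# Assumed locations, for illustration only.
# version_base=None opts out of Hydra's version-compat defaults (Hydra >= 1.2).
with initialize(config_path="lerobot/configs", version_base=None):
    cfg = compose(config_name="default", overrides=["device=cpu"])

assert cfg.device == "cpu"
```

The command-line equivalent is appending an override such as `device=cpu` to the training invocation.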
@@ -1,4 +1,5 @@
 import logging
+import warnings

 import hydra
 import numpy as np
@@ -115,7 +116,11 @@ def train(cfg: dict, out_dir=None, job_name=None):

     init_logging()

-    assert torch.cuda.is_available()
+    if cfg.device == "cuda":
+        assert torch.cuda.is_available()
+    else:
+        warnings.warn("Using CPU, this will be slow.", UserWarning, stacklevel=1)
+
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
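The unconditional `assert torch.cuda.is_available()` becomes a branch: hard-fail only when CUDA was explicitly requested, warn and continue otherwise. The same guard restated in isolation as a testable sketch; the assert message is added here and is not part of the diff:

```python
import warnings

import torch


def check_device(device: str) -> None:
    # Same logic the commit adds to train(); only the assert message is ours.
    if device == "cuda":
        assert torch.cuda.is_available(), "device=cuda requested but CUDA is unavailable"
    else:
        warnings.warn("Using CPU, this will be slow.", UserWarning, stacklevel=2)


check_device("cpu")  # warns but proceeds, enabling CPU-only runs
```

The two `torch.backends` flags that follow remain set unconditionally; they only influence CUDA kernels, so on a CPU run they should be harmless no-ops.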