@@ -341,102 +341,103 @@ class RandomShiftsAug(nn.Module):
         return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
 
 
-class Episode:
-    """Storage object for a single episode."""
+# TODO(aliberts): remove class
+# class Episode:
+#     """Storage object for a single episode."""
 
-    def __init__(self, cfg, init_obs):
-        action_dim = cfg.action_dim
+#     def __init__(self, cfg, init_obs):
+#         action_dim = cfg.action_dim
 
-        self.cfg = cfg
-        self.device = torch.device(cfg.buffer_device)
-        if cfg.modality in {"pixels", "state"}:
-            dtype = torch.float32 if cfg.modality == "state" else torch.uint8
-            self.obses = torch.empty(
-                (cfg.episode_length + 1, *init_obs.shape),
-                dtype=dtype,
-                device=self.device,
-            )
-            self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
-        elif cfg.modality == "all":
-            self.obses = {}
-            for k, v in init_obs.items():
-                assert k in {"rgb", "state"}
-                dtype = torch.float32 if k == "state" else torch.uint8
-                self.obses[k] = torch.empty(
-                    (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
-                )
-                self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
-        else:
-            raise ValueError
-        self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
-        self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
-        self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
-        self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
-        self.cumulative_reward = 0
-        self.done = False
-        self.success = False
-        self._idx = 0
+#         self.cfg = cfg
+#         self.device = torch.device(cfg.buffer_device)
+#         if cfg.modality in {"pixels", "state"}:
+#             dtype = torch.float32 if cfg.modality == "state" else torch.uint8
+#             self.obses = torch.empty(
+#                 (cfg.episode_length + 1, *init_obs.shape),
+#                 dtype=dtype,
+#                 device=self.device,
+#             )
+#             self.obses[0] = torch.tensor(init_obs, dtype=dtype, device=self.device)
+#         elif cfg.modality == "all":
+#             self.obses = {}
+#             for k, v in init_obs.items():
+#                 assert k in {"rgb", "state"}
+#                 dtype = torch.float32 if k == "state" else torch.uint8
+#                 self.obses[k] = torch.empty(
+#                     (cfg.episode_length + 1, *v.shape), dtype=dtype, device=self.device
+#                 )
+#                 self.obses[k][0] = torch.tensor(v, dtype=dtype, device=self.device)
+#         else:
+#             raise ValueError
+#         self.actions = torch.empty((cfg.episode_length, action_dim), dtype=torch.float32, device=self.device)
+#         self.rewards = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
+#         self.dones = torch.empty((cfg.episode_length,), dtype=torch.bool, device=self.device)
+#         self.masks = torch.empty((cfg.episode_length,), dtype=torch.float32, device=self.device)
+#         self.cumulative_reward = 0
+#         self.done = False
+#         self.success = False
+#         self._idx = 0
 
-    def __len__(self):
-        return self._idx
+#     def __len__(self):
+#         return self._idx
 
-    @classmethod
-    def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
-        """Constructs an episode from a trajectory."""
+#     @classmethod
+#     def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
+#         """Constructs an episode from a trajectory."""
 
-        if cfg.modality in {"pixels", "state"}:
-            episode = cls(cfg, obses[0])
-            episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
-        elif cfg.modality == "all":
-            episode = cls(cfg, {k: v[0] for k, v in obses.items()})
-            for k in obses:
-                episode.obses[k][1:] = torch.tensor(
-                    obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
-                )
-        else:
-            raise NotImplementedError
-        episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
-        episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
-        episode.dones = (
-            torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
-            if dones is not None
-            else torch.zeros_like(episode.dones)
-        )
-        episode.masks = (
-            torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
-            if masks is not None
-            else torch.ones_like(episode.masks)
-        )
-        episode.cumulative_reward = torch.sum(episode.rewards)
-        episode.done = True
-        episode._idx = cfg.episode_length
-        return episode
+#         if cfg.modality in {"pixels", "state"}:
+#             episode = cls(cfg, obses[0])
+#             episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
+#         elif cfg.modality == "all":
+#             episode = cls(cfg, {k: v[0] for k, v in obses.items()})
+#             for k in obses:
+#                 episode.obses[k][1:] = torch.tensor(
+#                     obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device
+#                 )
+#         else:
+#             raise NotImplementedError
+#         episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
+#         episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
+#         episode.dones = (
+#             torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
+#             if dones is not None
+#             else torch.zeros_like(episode.dones)
+#         )
+#         episode.masks = (
+#             torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
+#             if masks is not None
+#             else torch.ones_like(episode.masks)
+#         )
+#         episode.cumulative_reward = torch.sum(episode.rewards)
+#         episode.done = True
+#         episode._idx = cfg.episode_length
+#         return episode
 
-    @property
-    def first(self):
-        return len(self) == 0
+#     @property
+#     def first(self):
+#         return len(self) == 0
 
-    def __add__(self, transition):
-        self.add(*transition)
-        return self
+#     def __add__(self, transition):
+#         self.add(*transition)
+#         return self
 
-    def add(self, obs, action, reward, done, mask=1.0, success=False):
-        """Add a transition into the episode."""
-        if isinstance(obs, dict):
-            for k, v in obs.items():
-                self.obses[k][self._idx + 1] = torch.tensor(
-                    v, dtype=self.obses[k].dtype, device=self.obses[k].device
-                )
-        else:
-            self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
-        self.actions[self._idx] = action
-        self.rewards[self._idx] = reward
-        self.dones[self._idx] = done
-        self.masks[self._idx] = mask
-        self.cumulative_reward += reward
-        self.done = done
-        self.success = self.success or success
-        self._idx += 1
+#     def add(self, obs, action, reward, done, mask=1.0, success=False):
+#         """Add a transition into the episode."""
+#         if isinstance(obs, dict):
+#             for k, v in obs.items():
+#                 self.obses[k][self._idx + 1] = torch.tensor(
+#                     v, dtype=self.obses[k].dtype, device=self.obses[k].device
+#                 )
+#         else:
+#             self.obses[self._idx + 1] = torch.tensor(obs, dtype=self.obses.dtype, device=self.obses.device)
+#         self.actions[self._idx] = action
+#         self.rewards[self._idx] = reward
+#         self.dones[self._idx] = done
+#         self.masks[self._idx] = mask
+#         self.cumulative_reward += reward
+#         self.done = done
+#         self.success = self.success or success
+#         self._idx += 1
 
 
 def get_dataset_dict(cfg, env, return_reward_normalizer=False):