From f56b1a0e1630867691fd07329cd52c257ee3d44c Mon Sep 17 00:00:00 2001
From: Cadene
Date: Fri, 5 Apr 2024 13:40:31 +0000
Subject: [PATCH] WIP tdmpc

---
 lerobot/common/policies/tdmpc/policy.py | 207 +++++++++++++-----------
 lerobot/configs/env/simxarm.yaml        |   2 +-
 lerobot/configs/policy/tdmpc.yaml       |   2 +-
 sbatch.sh                               |   4 +-
 4 files changed, 121 insertions(+), 94 deletions(-)

diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py
index 64dcc94d..85700913 100644
--- a/lerobot/common/policies/tdmpc/policy.py
+++ b/lerobot/common/policies/tdmpc/policy.py
@@ -1,6 +1,7 @@
 # ruff: noqa: N806
 
 import time
+from collections import deque
 from copy import deepcopy
 
 import einops
@@ -9,7 +10,7 @@ import torch
 import torch.nn as nn
 
 import lerobot.common.policies.tdmpc.helper as h
-from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.policies.utils import populate_queues
 from lerobot.common.utils import get_safe_torch_device
 
 FIRST_FRAME = 0
@@ -87,16 +88,18 @@ class TOLD(nn.Module):
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
 
 
-class TDMPCPolicy(AbstractPolicy):
+class TDMPCPolicy(nn.Module):
     """Implementation of TD-MPC learning + inference."""
 
     name = "tdmpc"
 
-    def __init__(self, cfg, device):
-        super().__init__(None)
+    def __init__(self, cfg, n_obs_steps, n_action_steps, device):
+        super().__init__()
         self.action_dim = cfg.action_dim
 
         self.cfg = cfg
+        self.n_obs_steps = n_obs_steps
+        self.n_action_steps = n_action_steps
         self.device = get_safe_torch_device(device)
         self.std = h.linear_schedule(cfg.std_schedule, 0)
         self.model = TOLD(cfg)
@@ -128,20 +131,42 @@
         self.model.load_state_dict(d["model"])
         self.model_target.load_state_dict(d["model_target"])
 
-    @torch.no_grad()
-    def select_actions(self, observation, step_count):
-        if observation["image"].shape[0] != 1:
-            raise NotImplementedError("Batch size > 1 not handled")
-
-        t0 = step_count.item() == 0
-
-        obs = {
-            # TODO(rcadene): remove contiguous hack...
-            "rgb": observation["image"].contiguous(),
-            "state": observation["state"].contiguous(),
+    def reset(self):
+        """
+        Clear observation and action queues. Should be called on `env.reset()`
+        """
+        self._queues = {
+            "observation.image": deque(maxlen=self.n_obs_steps),
+            "observation.state": deque(maxlen=self.n_obs_steps),
+            "action": deque(maxlen=self.n_action_steps),
         }
-        # Note: unsqueeze needed because `act` still uses non-batch logic.
-        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
+
+    @torch.no_grad()
+    def select_action(self, batch, step):
+        assert "observation.image" in batch
+        assert "observation.state" in batch
+        assert len(batch) == 2
+
+        self._queues = populate_queues(self._queues, batch)
+
+        t0 = step == 0
+
+        if len(self._queues["action"]) == 0:
+            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+
+            actions = []
+            batch_size = batch["observation.image"].shape[0]
+            for i in range(batch_size):
+                obs = {
+                    "rgb": batch["observation.image"][[i]],
+                    "state": batch["observation.state"][[i]],
+                }
+                # Note: unsqueeze needed because `act` still uses non-batch logic.
+                action = self.act(obs, t0=t0, step=self.step)
+                actions.append(action)
+            action = torch.stack(actions)
+            self._queues["action"].append(action)  # otherwise `popleft` below fails on an empty deque
+        action = self._queues["action"].popleft()
         return action
 
     @torch.no_grad()
@@ -293,97 +318,97 @@ class TDMPCPolicy(AbstractPolicy):
         td_target = reward + self.cfg.discount * mask * next_v
         return td_target
 
-    def update(self, replay_buffer, step, demo_buffer=None):
+    def forward(self, batch, step):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()
 
-        num_slices = self.cfg.batch_size
-        batch_size = self.cfg.horizon * num_slices
+        # num_slices = self.cfg.batch_size
+        # batch_size = self.cfg.horizon * num_slices
 
-        if demo_buffer is None:
-            demo_batch_size = 0
-        else:
-            # Update oversampling ratio
-            demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
-            demo_num_slices = int(demo_pc_batch * self.batch_size)
-            demo_batch_size = self.cfg.horizon * demo_num_slices
-            batch_size -= demo_batch_size
-            num_slices -= demo_num_slices
-            replay_buffer._sampler.num_slices = num_slices
-            demo_buffer._sampler.num_slices = demo_num_slices
+        # if demo_buffer is None:
+        #     demo_batch_size = 0
+        # else:
+        #     # Update oversampling ratio
+        #     demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
+        #     demo_num_slices = int(demo_pc_batch * self.batch_size)
+        #     demo_batch_size = self.cfg.horizon * demo_num_slices
+        #     batch_size -= demo_batch_size
+        #     num_slices -= demo_num_slices
+        #     replay_buffer._sampler.num_slices = num_slices
+        #     demo_buffer._sampler.num_slices = demo_num_slices
 
-            assert demo_batch_size % self.cfg.horizon == 0
-            assert demo_batch_size % demo_num_slices == 0
+        #     assert demo_batch_size % self.cfg.horizon == 0
+        #     assert demo_batch_size % demo_num_slices == 0
 
-        assert batch_size % self.cfg.horizon == 0
-        assert batch_size % num_slices == 0
+        # assert batch_size % self.cfg.horizon == 0
+        # assert batch_size % num_slices == 0
 
-        # Sample from interaction dataset
+        # # Sample from interaction dataset
 
-        def process_batch(batch, horizon, num_slices):
-            # trajectory t = 256, horizon h = 5
-            # (t h) ... -> h t ...
-            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
+        # def process_batch(batch, horizon, num_slices):
+        #     # trajectory t = 256, horizon h = 5
+        #     # (t h) ... -> h t ...
+        #     batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
 
-            obs = {
-                "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
-                "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
-            }
-            action = batch["action"].to(self.device, non_blocking=True)
-            next_obses = {
-                "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
-                "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
-            }
-            reward = batch["next", "reward"].to(self.device, non_blocking=True)
+        #     obs = {
+        #         "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
+        #         "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
+        #     }
+        #     action = batch["action"].to(self.device, non_blocking=True)
+        #     next_obses = {
+        #         "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
+        #         "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
+        #     }
+        #     reward = batch["next", "reward"].to(self.device, non_blocking=True)
 
-            idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
-            weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
+        #     idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
+        #     weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
 
-            # TODO(rcadene): rearrange directly in offline dataset
-            if reward.ndim == 2:
-                reward = einops.rearrange(reward, "h t -> h t 1")
+        #     # TODO(rcadene): rearrange directly in offline dataset
+        #     if reward.ndim == 2:
+        #         reward = einops.rearrange(reward, "h t -> h t 1")
 
-            assert reward.ndim == 3
-            assert reward.shape == (horizon, num_slices, 1)
-            # We dont use `batch["next", "done"]` since it only indicates the end of an
-            # episode, but not the end of the trajectory of an episode.
-            # Neither does `batch["next", "terminated"]`
-            done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-            mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-            return obs, action, next_obses, reward, mask, done, idxs, weights
+        #     assert reward.ndim == 3
+        #     assert reward.shape == (horizon, num_slices, 1)
+        #     # We dont use `batch["next", "done"]` since it only indicates the end of an
+        #     # episode, but not the end of the trajectory of an episode.
+        #     # Neither does `batch["next", "terminated"]`
+        #     done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+        #     mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        #     return obs, action, next_obses, reward, mask, done, idxs, weights
 
-        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
+        # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
 
-        obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
-            batch, self.cfg.horizon, num_slices
-        )
+        # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
+        #     batch, self.cfg.horizon, num_slices
+        # )
 
         # Sample from demonstration dataset
-        if demo_batch_size > 0:
-            demo_batch = demo_buffer.sample(demo_batch_size)
-            (
-                demo_obs,
-                demo_action,
-                demo_next_obses,
-                demo_reward,
-                demo_mask,
-                demo_done,
-                demo_idxs,
-                demo_weights,
-            ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
+        # if demo_batch_size > 0:
+        #     demo_batch = demo_buffer.sample(demo_batch_size)
+        #     (
+        #         demo_obs,
+        #         demo_action,
+        #         demo_next_obses,
+        #         demo_reward,
+        #         demo_mask,
+        #         demo_done,
+        #         demo_idxs,
+        #         demo_weights,
+        #     ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
 
-            if isinstance(obs, dict):
-                obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
-                next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
-            else:
-                obs = torch.cat([obs, demo_obs])
-                next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
-            action = torch.cat([action, demo_action], dim=1)
-            reward = torch.cat([reward, demo_reward], dim=1)
-            mask = torch.cat([mask, demo_mask], dim=1)
-            done = torch.cat([done, demo_done], dim=1)
-            idxs = torch.cat([idxs, demo_idxs])
-            weights = torch.cat([weights, demo_weights])
+        #     if isinstance(obs, dict):
+        #         obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
+        #         next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
+        #     else:
+        #         obs = torch.cat([obs, demo_obs])
+        #         next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
+        #     action = torch.cat([action, demo_action], dim=1)
+        #     reward = torch.cat([reward, demo_reward], dim=1)
+        #     mask = torch.cat([mask, demo_mask], dim=1)
+        #     done = torch.cat([done, demo_done], dim=1)
+        #     idxs = torch.cat([idxs, demo_idxs])
+        #     weights = torch.cat([weights, demo_weights])
 
         # Apply augmentations
         aug_tf = h.aug(self.cfg)
diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/simxarm.yaml
index f79db8f7..843f80c6 100644
--- a/lerobot/configs/env/simxarm.yaml
+++ b/lerobot/configs/env/simxarm.yaml
@@ -17,7 +17,7 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 84
-  action_repeat: 2
+  # action_repeat: 2 # we can remove if policy has n_action_steps=2
   episode_length: 25
   fps: ${fps}
 
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index ff0e6b04..5d5d8b62 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -1,6 +1,6 @@
 # @package _global_
 
-n_action_steps: 1
+n_action_steps: 2
 n_obs_steps: 1
 
 policy:
diff --git a/sbatch.sh b/sbatch.sh
index cb5b285a..c08f7055 100644
--- a/sbatch.sh
+++ b/sbatch.sh
@@ -6,7 +6,7 @@
 #SBATCH --time=2-00:00:00
 #SBATCH --output=/home/rcadene/slurm/%j.out
 #SBATCH --error=/home/rcadene/slurm/%j.err
-#SBATCH --qos=medium
+#SBATCH --qos=low
 #SBATCH --mail-user=re.cadene@gmail.com
 #SBATCH --mail-type=ALL
 
@@ -20,4 +20,6 @@ source ~/.bashrc
#conda activate fowm conda activate lerobot +export DATA_DIR="data" + srun $CMD
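
Usage sketch (illustrative only, not part of the patch): the snippet below shows one way the queue-based `reset()` / `select_action()` API introduced above could be driven from an evaluation loop. The gym-style `env`, the `rollout` helper, and the packing of raw observations into `observation.image` / `observation.state` tensors are assumptions made for the example, not code from this commit.

# Illustrative sketch only. Assumes a gym-style `env` returning dict
# observations with "image" and "state" keys, and a policy constructed as in
# this patch: TDMPCPolicy(cfg, n_obs_steps, n_action_steps, device).
import torch


def rollout(env, policy, max_steps=25):
    policy.reset()  # clear the observation/action deques on every env reset
    obs = env.reset()
    for step in range(max_steps):
        # Pack the current observation into the batched format expected by
        # `select_action` (batch dimension first).
        batch = {
            "observation.image": torch.as_tensor(obs["image"]).unsqueeze(0),
            "observation.state": torch.as_tensor(obs["state"]).unsqueeze(0),
        }
        # `select_action` pushes the observation into its internal queues,
        # re-plans when the action queue is empty, and pops one action per call.
        action = policy.select_action(batch, step)
        obs, reward, done, info = env.step(action[0].cpu().numpy())
        if done:
            break

With actions buffered in a deque of length `n_action_steps`, the policy rather than the environment decides how often new actions are produced, which is why `action_repeat` is commented out in `simxarm.yaml` while `n_action_steps` is raised to 2 in `tdmpc.yaml`.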