WIP tdmpc

Cadene 2024-04-05 13:40:31 +00:00
parent ab3cd3a7ba
commit f56b1a0e16
4 changed files with 121 additions and 94 deletions

View File

@@ -1,6 +1,7 @@
 # ruff: noqa: N806
 import time
+from collections import deque
 from copy import deepcopy
 import einops
@@ -9,7 +10,7 @@ import torch
 import torch.nn as nn
 import lerobot.common.policies.tdmpc.helper as h
-from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.policies.utils import populate_queues
 from lerobot.common.utils import get_safe_torch_device
 FIRST_FRAME = 0
@@ -87,16 +88,18 @@ class TOLD(nn.Module):
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2

-class TDMPCPolicy(AbstractPolicy):
+class TDMPCPolicy(nn.Module):
     """Implementation of TD-MPC learning + inference."""
     name = "tdmpc"
-    def __init__(self, cfg, device):
-        super().__init__(None)
+    def __init__(self, cfg, n_obs_steps, n_action_steps, device):
+        super().__init__()
         self.action_dim = cfg.action_dim
         self.cfg = cfg
+        self.n_obs_steps = n_obs_steps
+        self.n_action_steps = n_action_steps
         self.device = get_safe_torch_device(device)
         self.std = h.linear_schedule(cfg.std_schedule, 0)
         self.model = TOLD(cfg)
@@ -128,20 +131,42 @@ class TDMPCPolicy(AbstractPolicy):
         self.model.load_state_dict(d["model"])
         self.model_target.load_state_dict(d["model_target"])

-    @torch.no_grad()
-    def select_actions(self, observation, step_count):
-        if observation["image"].shape[0] != 1:
-            raise NotImplementedError("Batch size > 1 not handled")
-        t0 = step_count.item() == 0
-        obs = {
-            # TODO(rcadene): remove contiguous hack...
-            "rgb": observation["image"].contiguous(),
-            "state": observation["state"].contiguous(),
-        }
-        # Note: unsqueeze needed because `act` still uses non-batch logic.
-        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
+    def reset(self):
+        """
+        Clear observation and action queues. Should be called on `env.reset()`
+        """
+        self._queues = {
+            "observation.image": deque(maxlen=self.n_obs_steps),
+            "observation.state": deque(maxlen=self.n_obs_steps),
+            "action": deque(maxlen=self.n_action_steps),
+        }
+
+    @torch.no_grad()
+    def select_action(self, batch, step):
+        assert "observation.image" in batch
+        assert "observation.state" in batch
+        assert len(batch) == 2
+        self._queues = populate_queues(self._queues, batch)
+        t0 = step == 0
+        if len(self._queues["action"]) == 0:
+            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            actions = []
+            batch_size = batch["observation.image"].shape[0]
+            for i in range(batch_size):
+                obs = {
+                    "rgb": batch["observation.image"][[i]],
+                    "state": batch["observation.state"][[i]],
+                }
+                # Note: unsqueeze needed because `act` still uses non-batch logic.
+                action = self.act(obs, t0=t0, step=self.step)
+                actions.append(action)
+            action = torch.stack(actions)
+        action = self._queues["action"].popleft()
         return action

     @torch.no_grad()
@@ -293,97 +318,97 @@ class TDMPCPolicy(AbstractPolicy):
         td_target = reward + self.cfg.discount * mask * next_v
         return td_target

-    def update(self, replay_buffer, step, demo_buffer=None):
+    def forward(self, batch, step):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()
-        num_slices = self.cfg.batch_size
-        batch_size = self.cfg.horizon * num_slices
-        if demo_buffer is None:
-            demo_batch_size = 0
-        else:
-            # Update oversampling ratio
-            demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
-            demo_num_slices = int(demo_pc_batch * self.batch_size)
-            demo_batch_size = self.cfg.horizon * demo_num_slices
-            batch_size -= demo_batch_size
-            num_slices -= demo_num_slices
-            replay_buffer._sampler.num_slices = num_slices
-            demo_buffer._sampler.num_slices = demo_num_slices
-            assert demo_batch_size % self.cfg.horizon == 0
-            assert demo_batch_size % demo_num_slices == 0
-        assert batch_size % self.cfg.horizon == 0
-        assert batch_size % num_slices == 0
-        # Sample from interaction dataset
-        def process_batch(batch, horizon, num_slices):
-            # trajectory t = 256, horizon h = 5
-            # (t h) ... -> h t ...
-            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-            obs = {
-                "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
-                "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
-            }
-            action = batch["action"].to(self.device, non_blocking=True)
-            next_obses = {
-                "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
-                "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
-            }
-            reward = batch["next", "reward"].to(self.device, non_blocking=True)
-            idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
-            weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
-            # TODO(rcadene): rearrange directly in offline dataset
-            if reward.ndim == 2:
-                reward = einops.rearrange(reward, "h t -> h t 1")
-            assert reward.ndim == 3
-            assert reward.shape == (horizon, num_slices, 1)
-            # We dont use `batch["next", "done"]` since it only indicates the end of an
-            # episode, but not the end of the trajectory of an episode.
-            # Neither does `batch["next", "terminated"]`
-            done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-            mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-            return obs, action, next_obses, reward, mask, done, idxs, weights
-        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
-        obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
-            batch, self.cfg.horizon, num_slices
-        )
+        # num_slices = self.cfg.batch_size
+        # batch_size = self.cfg.horizon * num_slices
+        # if demo_buffer is None:
+        #     demo_batch_size = 0
+        # else:
+        #     # Update oversampling ratio
+        #     demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
+        #     demo_num_slices = int(demo_pc_batch * self.batch_size)
+        #     demo_batch_size = self.cfg.horizon * demo_num_slices
+        #     batch_size -= demo_batch_size
+        #     num_slices -= demo_num_slices
+        #     replay_buffer._sampler.num_slices = num_slices
+        #     demo_buffer._sampler.num_slices = demo_num_slices
+        #     assert demo_batch_size % self.cfg.horizon == 0
+        #     assert demo_batch_size % demo_num_slices == 0
+        # assert batch_size % self.cfg.horizon == 0
+        # assert batch_size % num_slices == 0
+        # # Sample from interaction dataset
+        # def process_batch(batch, horizon, num_slices):
+        #     # trajectory t = 256, horizon h = 5
+        #     # (t h) ... -> h t ...
+        #     batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
+        #     obs = {
+        #         "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
+        #         "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
+        #     }
+        #     action = batch["action"].to(self.device, non_blocking=True)
+        #     next_obses = {
+        #         "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
+        #         "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
+        #     }
+        #     reward = batch["next", "reward"].to(self.device, non_blocking=True)
+        #     idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
+        #     weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
+        #     # TODO(rcadene): rearrange directly in offline dataset
+        #     if reward.ndim == 2:
+        #         reward = einops.rearrange(reward, "h t -> h t 1")
+        #     assert reward.ndim == 3
+        #     assert reward.shape == (horizon, num_slices, 1)
+        #     # We dont use `batch["next", "done"]` since it only indicates the end of an
+        #     # episode, but not the end of the trajectory of an episode.
+        #     # Neither does `batch["next", "terminated"]`
+        #     done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+        #     mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        #     return obs, action, next_obses, reward, mask, done, idxs, weights
+        # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
+        # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
+        #     batch, self.cfg.horizon, num_slices
+        # )
         # Sample from demonstration dataset
-        if demo_batch_size > 0:
-            demo_batch = demo_buffer.sample(demo_batch_size)
-            (
-                demo_obs,
-                demo_action,
-                demo_next_obses,
-                demo_reward,
-                demo_mask,
-                demo_done,
-                demo_idxs,
-                demo_weights,
-            ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
-            if isinstance(obs, dict):
-                obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
-                next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
-            else:
-                obs = torch.cat([obs, demo_obs])
-                next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
-            action = torch.cat([action, demo_action], dim=1)
-            reward = torch.cat([reward, demo_reward], dim=1)
-            mask = torch.cat([mask, demo_mask], dim=1)
-            done = torch.cat([done, demo_done], dim=1)
-            idxs = torch.cat([idxs, demo_idxs])
-            weights = torch.cat([weights, demo_weights])
+        # if demo_batch_size > 0:
+        #     demo_batch = demo_buffer.sample(demo_batch_size)
+        #     (
+        #         demo_obs,
+        #         demo_action,
+        #         demo_next_obses,
+        #         demo_reward,
+        #         demo_mask,
+        #         demo_done,
+        #         demo_idxs,
+        #         demo_weights,
+        #     ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
+        #     if isinstance(obs, dict):
+        #         obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
+        #         next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
+        #     else:
+        #         obs = torch.cat([obs, demo_obs])
+        #         next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
+        #     action = torch.cat([action, demo_action], dim=1)
+        #     reward = torch.cat([reward, demo_reward], dim=1)
+        #     mask = torch.cat([mask, demo_mask], dim=1)
+        #     done = torch.cat([done, demo_done], dim=1)
+        #     idxs = torch.cat([idxs, demo_idxs])
+        #     weights = torch.cat([weights, demo_weights])
         # Apply augmentations
         aug_tf = h.aug(self.cfg)
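The new reset()/select_action() pair makes TDMPCPolicy a stateful, queue-driven module: observations are pushed into per-key deques of length n_obs_steps, actions are served from a deque of length n_action_steps, and the planner only runs when the action queue is empty. Below is a minimal, self-contained sketch of that rollout pattern. DummyTDMPC, its zero-action planner, and the step that refills the action queue are illustrative assumptions (the WIP diff stacks the planned actions but does not yet show them being pushed into the queue); only the deque bookkeeping mirrors the diff.

from collections import deque

import torch


class DummyTDMPC:
    """Illustrative stand-in for the queue bookkeeping in TDMPCPolicy (not the real class)."""

    def __init__(self, n_obs_steps=1, n_action_steps=2, action_dim=4):
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.action_dim = action_dim

    def reset(self):
        # Mirrors TDMPCPolicy.reset(): called on `env.reset()` to clear the queues.
        self._queues = {
            "observation.image": deque(maxlen=self.n_obs_steps),
            "observation.state": deque(maxlen=self.n_obs_steps),
            "action": deque(maxlen=self.n_action_steps),
        }

    @torch.no_grad()
    def select_action(self, batch, step):
        # Inlined stand-in for `populate_queues`: append the latest observations.
        for key in ("observation.image", "observation.state"):
            self._queues[key].append(batch[key])

        if len(self._queues["action"]) == 0:
            # Plan only when the action queue is empty; a zero action replaces the
            # real sampling-based planner in `act`.
            planned = torch.zeros(batch["observation.state"].shape[0], self.action_dim)
            # Assumption: the single planned action is enqueued n_action_steps times,
            # so it is applied for n_action_steps consecutive env steps.
            self._queues["action"].extend([planned] * self.n_action_steps)

        return self._queues["action"].popleft()


policy = DummyTDMPC(n_obs_steps=1, n_action_steps=2)
policy.reset()
for step in range(4):
    batch = {
        "observation.image": torch.zeros(1, 3, 84, 84),
        "observation.state": torch.zeros(1, 4),
    }
    action = policy.select_action(batch, step)
    print(step, action.shape)  # the planner only runs on steps 0 and 2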

View File

@@ -17,7 +17,7 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 84
-  action_repeat: 2
+  # action_repeat: 2 # we can remove if policy has n_action_steps=2
   episode_length: 25
   fps: ${fps}
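Commenting out action_repeat: 2 here is paired with the n_action_steps: 1 -> 2 change in the policy config below: instead of an env wrapper stepping the simulator twice per policy action, the policy-side action queue is expected to hand out one planned action for two consecutive control steps. A small counting sketch of that assumption, continuing the deque example above (illustrative names, not lerobot code):

from collections import deque

n_action_steps = 2
action_queue = deque(maxlen=n_action_steps)
planner_calls = 0

def plan(step):
    # Stand-in planner: TD-MPC returns one action per planning call.
    global planner_calls
    planner_calls += 1
    return f"a{step}"

for step in range(6):
    if len(action_queue) == 0:
        # Assumption: the planned action is enqueued n_action_steps times.
        action_queue.extend([plan(step)] * n_action_steps)
    action = action_queue.popleft()
    print(step, action)

print(planner_calls)  # 3 planning calls for 6 env steps, i.e. an effective repeat of 2

Either way the planner runs every other control step; moving the repeat into the policy just removes the env-side wrapper.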

View File

@@ -1,6 +1,6 @@
 # @package _global_
-n_action_steps: 1
+n_action_steps: 2
 n_obs_steps: 1
 policy:

View File

@@ -6,7 +6,7 @@
 #SBATCH --time=2-00:00:00
 #SBATCH --output=/home/rcadene/slurm/%j.out
 #SBATCH --error=/home/rcadene/slurm/%j.err
-#SBATCH --qos=medium
+#SBATCH --qos=low
 #SBATCH --mail-user=re.cadene@gmail.com
 #SBATCH --mail-type=ALL
@@ -20,4 +20,6 @@ source ~/.bashrc
 #conda activate fowm
 conda activate lerobot
+export DATA_DIR="data"
 srun $CMD