WIP tdmpc

Cadene 2024-04-05 13:40:31 +00:00
parent ab3cd3a7ba
commit f56b1a0e16
4 changed files with 121 additions and 94 deletions


@@ -1,6 +1,7 @@
# ruff: noqa: N806
import time
from collections import deque
from copy import deepcopy
import einops
@@ -9,7 +10,7 @@ import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
FIRST_FRAME = 0
@@ -87,16 +88,18 @@ class TOLD(nn.Module):
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPCPolicy(AbstractPolicy):
class TDMPCPolicy(nn.Module):
"""Implementation of TD-MPC learning + inference."""
name = "tdmpc"
def __init__(self, cfg, device):
super().__init__(None)
def __init__(self, cfg, n_obs_steps, n_action_steps, device):
super().__init__()
self.action_dim = cfg.action_dim
self.cfg = cfg
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg)
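For reference, h.linear_schedule(cfg.std_schedule, 0) above evaluates a schedule spec at a given training step. A minimal sketch of such a parser, assuming the usual TD-MPC-style "linear(init, final, duration)" string format:

import re

def linear_schedule(schedule: str, step: int) -> float:
    """Evaluate a schedule spec: a constant like "0.5", or "linear(init, final, duration)"."""
    try:
        return float(schedule)  # constant schedule
    except ValueError:
        match = re.match(r"linear\((.+),(.+),(.+)\)", schedule)
        if match:
            init, final, duration = (float(g) for g in match.groups())
            mix = min(max(step / duration, 0.0), 1.0)  # clamp progress to [0, 1]
            return (1.0 - mix) * init + mix * final
    raise NotImplementedError(schedule)

# e.g. a spec of "linear(0.5, 0.05, 25000)" (hypothetical values) yields 0.5 at step 0
# and decays linearly to 0.05 over 25k steps.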
@@ -128,20 +131,42 @@ class TDMPCPolicy(AbstractPolicy):
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
t0 = step_count.item() == 0
obs = {
# TODO(rcadene): remove contiguous hack...
"rgb": observation["image"].contiguous(),
"state": observation["state"].contiguous(),
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
@torch.no_grad()
def select_action(self, batch, step):
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
t0 = step == 0
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
actions = []
batch_size = batch["observation.image"].shape[0]
for i in range(batch_size):
obs = {
"rgb": batch["observation.image"][[i]],
"state": batch["observation.state"][[i]],
}
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step)
actions.append(action)
action = torch.stack(actions)
# queue the planned action(s) so the popleft below never hits an empty deque
self._queues["action"].append(action)
action = self._queues["action"].popleft()
return action
@torch.no_grad()
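populate_queues (imported above) is expected to keep the per-key deques filled: on the first call after reset() it replicates the initial item until the queue is full, and afterwards it appends the latest one while the oldest falls off. A minimal sketch consistent with how it is used here:

from collections import deque

def populate_queues(queues: dict, batch: dict) -> dict:
    for key in batch:
        if len(queues[key]) != queues[key].maxlen:
            # first call after a reset: replicate the initial item to fill the queue
            while len(queues[key]) != queues[key].maxlen:
                queues[key].append(batch[key])
        else:
            # steady state: append the latest item, maxlen drops the oldest
            queues[key].append(batch[key])
    return queues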
@@ -293,97 +318,97 @@ class TDMPCPolicy(AbstractPolicy):
td_target = reward + self.cfg.discount * mask * next_v
return td_target
def update(self, replay_buffer, step, demo_buffer=None):
def forward(self, batch, step):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
# num_slices = self.cfg.batch_size
# batch_size = self.cfg.horizon * num_slices
if demo_buffer is None:
demo_batch_size = 0
else:
# Update oversampling ratio
demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
demo_num_slices = int(demo_pc_batch * self.batch_size)
demo_batch_size = self.cfg.horizon * demo_num_slices
batch_size -= demo_batch_size
num_slices -= demo_num_slices
replay_buffer._sampler.num_slices = num_slices
demo_buffer._sampler.num_slices = demo_num_slices
# if demo_buffer is None:
# demo_batch_size = 0
# else:
# # Update oversampling ratio
# demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
# demo_num_slices = int(demo_pc_batch * self.batch_size)
# demo_batch_size = self.cfg.horizon * demo_num_slices
# batch_size -= demo_batch_size
# num_slices -= demo_num_slices
# replay_buffer._sampler.num_slices = num_slices
# demo_buffer._sampler.num_slices = demo_num_slices
assert demo_batch_size % self.cfg.horizon == 0
assert demo_batch_size % demo_num_slices == 0
# assert demo_batch_size % self.cfg.horizon == 0
# assert demo_batch_size % demo_num_slices == 0
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
# assert batch_size % self.cfg.horizon == 0
# assert batch_size % num_slices == 0
# Sample from interaction dataset
# # Sample from interaction dataset
def process_batch(batch, horizon, num_slices):
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
# def process_batch(batch, horizon, num_slices):
# # trajectory t = 256, horizon h = 5
# # (t h) ... -> h t ...
# batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
"state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
}
action = batch["action"].to(self.device, non_blocking=True)
next_obses = {
"rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
"state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
}
reward = batch["next", "reward"].to(self.device, non_blocking=True)
# obs = {
# "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
# "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
# }
# action = batch["action"].to(self.device, non_blocking=True)
# next_obses = {
# "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
# "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
# }
# reward = batch["next", "reward"].to(self.device, non_blocking=True)
idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
# idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
# weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
# TODO(rcadene): rearrange directly in offline dataset
if reward.ndim == 2:
reward = einops.rearrange(reward, "h t -> h t 1")
# # TODO(rcadene): rearrange directly in offline dataset
# if reward.ndim == 2:
# reward = einops.rearrange(reward, "h t -> h t 1")
assert reward.ndim == 3
assert reward.shape == (horizon, num_slices, 1)
# We don't use `batch["next", "done"]` since it only indicates the end of an
# episode, but not the end of the trajectory of an episode.
# Neither does `batch["next", "terminated"]`
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
return obs, action, next_obses, reward, mask, done, idxs, weights
# assert reward.ndim == 3
# assert reward.shape == (horizon, num_slices, 1)
# # We don't use `batch["next", "done"]` since it only indicates the end of an
# # episode, but not the end of the trajectory of an episode.
# # Neither does `batch["next", "terminated"]`
# done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
# mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
# return obs, action, next_obses, reward, mask, done, idxs, weights
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
# batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
batch, self.cfg.horizon, num_slices
)
# obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
# batch, self.cfg.horizon, num_slices
# )
# Sample from demonstration dataset
if demo_batch_size > 0:
demo_batch = demo_buffer.sample(demo_batch_size)
(
demo_obs,
demo_action,
demo_next_obses,
demo_reward,
demo_mask,
demo_done,
demo_idxs,
demo_weights,
) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
# if demo_batch_size > 0:
# demo_batch = demo_buffer.sample(demo_batch_size)
# (
# demo_obs,
# demo_action,
# demo_next_obses,
# demo_reward,
# demo_mask,
# demo_done,
# demo_idxs,
# demo_weights,
# ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
if isinstance(obs, dict):
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
else:
obs = torch.cat([obs, demo_obs])
next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
action = torch.cat([action, demo_action], dim=1)
reward = torch.cat([reward, demo_reward], dim=1)
mask = torch.cat([mask, demo_mask], dim=1)
done = torch.cat([done, demo_done], dim=1)
idxs = torch.cat([idxs, demo_idxs])
weights = torch.cat([weights, demo_weights])
# if isinstance(obs, dict):
# obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
# next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
# else:
# obs = torch.cat([obs, demo_obs])
# next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
# action = torch.cat([action, demo_action], dim=1)
# reward = torch.cat([reward, demo_reward], dim=1)
# mask = torch.cat([mask, demo_mask], dim=1)
# done = torch.cat([done, demo_done], dim=1)
# idxs = torch.cat([idxs, demo_idxs])
# weights = torch.cat([weights, demo_weights])
# Apply augmentations
aug_tf = h.aug(self.cfg)
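h.aug(self.cfg) builds the image augmentation. In the original TD-MPC/FOWM helpers this is a random-shift augmentation when learning from pixels (identity otherwise); a self-contained sketch of that transform, in the DrQ-v2 style:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RandomShiftsAug(nn.Module):
    """Randomly shift each image by up to pad pixels, with replicate padding."""

    def __init__(self, pad: int = 4):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        n, _, h, w = x.size()
        assert h == w, "expects square images"
        x = F.pad(x, [self.pad] * 4, "replicate")
        eps = 1.0 / (h + 2 * self.pad)
        # base sampling grid covering the original (unpadded) image area
        arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h]
        arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2).unsqueeze(0).repeat(n, 1, 1, 1)
        # one random integer shift per image, expressed in grid coordinates
        shift = torch.randint(0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype)
        shift *= 2.0 / (h + 2 * self.pad)
        return F.grid_sample(x, base_grid + shift, padding_mode="zeros", align_corners=False)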


@@ -17,7 +17,7 @@ env:
from_pixels: True
pixels_only: False
image_size: 84
action_repeat: 2
# action_repeat: 2  # can be removed now that the policy has n_action_steps=2
episode_length: 25
fps: ${fps}
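Context for dropping action_repeat: repeating an action now falls to the policy's action queue rather than the env wrapper. A hypothetical rollout loop under that assumption (the obs-to-batch packing below is invented for illustration):

obs = env.reset()
policy.reset()  # re-initializes the observation/action deques
for step in range(episode_length):
    batch = {"observation.image": obs["image"], "observation.state": obs["state"]}
    # one queued action is popped per env step; the planner only runs when the queue is empty
    action = policy.select_action(batch, step)
    obs, reward, done, info = env.step(action)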


@@ -1,6 +1,6 @@
# @package _global_
n_action_steps: 1
n_action_steps: 2
n_obs_steps: 1
policy:
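These values presumably flow into the new TDMPCPolicy constructor; a hedged wiring sketch (Hydra field names assumed, not shown in this diff):

policy = TDMPCPolicy(
    cfg.policy,
    n_obs_steps=cfg.n_obs_steps,        # 1: observation queues hold a single frame
    n_action_steps=cfg.n_action_steps,  # 2: one plan is consumed over two env steps
    device=cfg.device,
)
policy.reset()  # must be called before select_action so the queues exist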


@@ -6,7 +6,7 @@
#SBATCH --time=2-00:00:00
#SBATCH --output=/home/rcadene/slurm/%j.out
#SBATCH --error=/home/rcadene/slurm/%j.err
#SBATCH --qos=medium
#SBATCH --qos=low
#SBATCH --mail-user=re.cadene@gmail.com
#SBATCH --mail-type=ALL
@@ -20,4 +20,6 @@ source ~/.bashrc
#conda activate fowm
conda activate lerobot
export DATA_DIR="data"
srun $CMD
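Usage note: since the script ends in srun $CMD, the training command is presumably supplied through the environment at submission time, e.g. CMD="python lerobot/scripts/train.py policy=tdmpc" sbatch <this script> (sbatch exports the caller's environment by default). The exact entry point is an assumption, not part of this diff.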