WIP tdmpc
parent ab3cd3a7ba
commit f56b1a0e16
@@ -1,6 +1,7 @@
 # ruff: noqa: N806
 
 import time
+from collections import deque
 from copy import deepcopy
 
 import einops
@@ -9,7 +10,7 @@ import torch
 import torch.nn as nn
 
 import lerobot.common.policies.tdmpc.helper as h
-from lerobot.common.policies.abstract import AbstractPolicy
+from lerobot.common.policies.utils import populate_queues
 from lerobot.common.utils import get_safe_torch_device
 
 FIRST_FRAME = 0
@@ -87,16 +88,18 @@ class TOLD(nn.Module):
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
 
 
-class TDMPCPolicy(AbstractPolicy):
+class TDMPCPolicy(nn.Module):
     """Implementation of TD-MPC learning + inference."""
 
     name = "tdmpc"
 
-    def __init__(self, cfg, device):
-        super().__init__(None)
+    def __init__(self, cfg, n_obs_steps, n_action_steps, device):
+        super().__init__()
         self.action_dim = cfg.action_dim
 
         self.cfg = cfg
+        self.n_obs_steps = n_obs_steps
+        self.n_action_steps = n_action_steps
         self.device = get_safe_torch_device(device)
         self.std = h.linear_schedule(cfg.std_schedule, 0)
         self.model = TOLD(cfg)
@@ -128,20 +131,42 @@ class TDMPCPolicy(AbstractPolicy):
         self.model.load_state_dict(d["model"])
         self.model_target.load_state_dict(d["model_target"])
 
-    @torch.no_grad()
-    def select_actions(self, observation, step_count):
-        if observation["image"].shape[0] != 1:
-            raise NotImplementedError("Batch size > 1 not handled")
-
-        t0 = step_count.item() == 0
-
-        obs = {
-            # TODO(rcadene): remove contiguous hack...
-            "rgb": observation["image"].contiguous(),
-            "state": observation["state"].contiguous(),
-        }
-        # Note: unsqueeze needed because `act` still uses non-batch logic.
-        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
+    def reset(self):
+        """
+        Clear observation and action queues. Should be called on `env.reset()`.
+        """
+        self._queues = {
+            "observation.image": deque(maxlen=self.n_obs_steps),
+            "observation.state": deque(maxlen=self.n_obs_steps),
+            "action": deque(maxlen=self.n_action_steps),
+        }
+
+    @torch.no_grad()
+    def select_action(self, batch, step):
+        assert "observation.image" in batch
+        assert "observation.state" in batch
+        assert len(batch) == 2
+
+        self._queues = populate_queues(self._queues, batch)
+
+        t0 = step == 0
+
+        if len(self._queues["action"]) == 0:
+            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+
+            actions = []
+            batch_size = batch["observation.image"].shape[0]
+            for i in range(batch_size):
+                obs = {
+                    "rgb": batch["observation.image"][[i]],
+                    "state": batch["observation.state"][[i]],
+                }
+                # Note: `act` still uses non-batch logic, so samples are processed one at a time.
+                action = self.act(obs, t0=t0, step=self.step)
+                actions.append(action)
+            action = torch.stack(actions)
+
+        action = self._queues["action"].popleft()
         return action
 
     @torch.no_grad()
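Aside: a minimal sketch of how the new queue-based inference is meant to be driven, with a hypothetical stand-in for `populate_queues` (the real helper is imported from `lerobot.common.policies.utils`; its exact behavior is assumed here, not copied from the source).

```python
from collections import deque

import torch


def populate_queues_sketch(queues: dict[str, deque], batch: dict[str, torch.Tensor]) -> dict[str, deque]:
    """Hypothetical stand-in for `populate_queues`: keep the most recent `maxlen`
    entries per key; on the first call, repeat the current entry so a full
    observation history is available immediately."""
    for key in batch:
        if key not in queues:
            continue
        if len(queues[key]) == 0:
            # First step after reset: fill the history with copies of the current entry.
            while len(queues[key]) < queues[key].maxlen:
                queues[key].append(batch[key])
        else:
            queues[key].append(batch[key])
    return queues


# Assumed eval-loop usage (env API and tensor shapes are illustrative only):
# policy.reset()                                # clears the queues on env.reset()
# for step in range(max_steps):
#     batch = {
#         "observation.image": image_tensor,    # (B, C, H, W)
#         "observation.state": state_tensor,    # (B, D)
#     }
#     action = policy.select_action(batch, step)
```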
@@ -293,97 +318,97 @@ class TDMPCPolicy(AbstractPolicy):
         td_target = reward + self.cfg.discount * mask * next_v
         return td_target
 
-    def update(self, replay_buffer, step, demo_buffer=None):
+    def forward(self, batch, step):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()
 
-        num_slices = self.cfg.batch_size
-        batch_size = self.cfg.horizon * num_slices
+        # num_slices = self.cfg.batch_size
+        # batch_size = self.cfg.horizon * num_slices
 
-        if demo_buffer is None:
-            demo_batch_size = 0
-        else:
-            # Update oversampling ratio
-            demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
-            demo_num_slices = int(demo_pc_batch * self.batch_size)
-            demo_batch_size = self.cfg.horizon * demo_num_slices
-            batch_size -= demo_batch_size
-            num_slices -= demo_num_slices
-            replay_buffer._sampler.num_slices = num_slices
-            demo_buffer._sampler.num_slices = demo_num_slices
+        # if demo_buffer is None:
+        # demo_batch_size = 0
+        # else:
+        # # Update oversampling ratio
+        # demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
+        # demo_num_slices = int(demo_pc_batch * self.batch_size)
+        # demo_batch_size = self.cfg.horizon * demo_num_slices
+        # batch_size -= demo_batch_size
+        # num_slices -= demo_num_slices
+        # replay_buffer._sampler.num_slices = num_slices
+        # demo_buffer._sampler.num_slices = demo_num_slices
 
-            assert demo_batch_size % self.cfg.horizon == 0
-            assert demo_batch_size % demo_num_slices == 0
+        # assert demo_batch_size % self.cfg.horizon == 0
+        # assert demo_batch_size % demo_num_slices == 0
 
-        assert batch_size % self.cfg.horizon == 0
-        assert batch_size % num_slices == 0
+        # assert batch_size % self.cfg.horizon == 0
+        # assert batch_size % num_slices == 0
 
-        # Sample from interaction dataset
+        # # Sample from interaction dataset
 
-        def process_batch(batch, horizon, num_slices):
-            # trajectory t = 256, horizon h = 5
-            # (t h) ... -> h t ...
-            batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
+        # def process_batch(batch, horizon, num_slices):
+        # # trajectory t = 256, horizon h = 5
+        # # (t h) ... -> h t ...
+        # batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
 
-            obs = {
-                "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
-                "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
-            }
-            action = batch["action"].to(self.device, non_blocking=True)
-            next_obses = {
-                "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
-                "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
-            }
-            reward = batch["next", "reward"].to(self.device, non_blocking=True)
+        # obs = {
+        # "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
+        # "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
+        # }
+        # action = batch["action"].to(self.device, non_blocking=True)
+        # next_obses = {
+        # "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
+        # "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
+        # }
+        # reward = batch["next", "reward"].to(self.device, non_blocking=True)
 
-            idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
-            weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
+        # idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
+        # weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
 
-            # TODO(rcadene): rearrange directly in offline dataset
-            if reward.ndim == 2:
-                reward = einops.rearrange(reward, "h t -> h t 1")
+        # # TODO(rcadene): rearrange directly in offline dataset
+        # if reward.ndim == 2:
+        # reward = einops.rearrange(reward, "h t -> h t 1")
 
-            assert reward.ndim == 3
-            assert reward.shape == (horizon, num_slices, 1)
-            # We dont use `batch["next", "done"]` since it only indicates the end of an
-            # episode, but not the end of the trajectory of an episode.
-            # Neither does `batch["next", "terminated"]`
-            done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-            mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-            return obs, action, next_obses, reward, mask, done, idxs, weights
+        # assert reward.ndim == 3
+        # assert reward.shape == (horizon, num_slices, 1)
+        # # We dont use `batch["next", "done"]` since it only indicates the end of an
+        # # episode, but not the end of the trajectory of an episode.
+        # # Neither does `batch["next", "terminated"]`
+        # done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+        # mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+        # return obs, action, next_obses, reward, mask, done, idxs, weights
 
-        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
+        # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
 
-        obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
-            batch, self.cfg.horizon, num_slices
-        )
+        # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
+        # batch, self.cfg.horizon, num_slices
+        # )
 
         # Sample from demonstration dataset
-        if demo_batch_size > 0:
-            demo_batch = demo_buffer.sample(demo_batch_size)
-            (
-                demo_obs,
-                demo_action,
-                demo_next_obses,
-                demo_reward,
-                demo_mask,
-                demo_done,
-                demo_idxs,
-                demo_weights,
-            ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
+        # if demo_batch_size > 0:
+        # demo_batch = demo_buffer.sample(demo_batch_size)
+        # (
+        # demo_obs,
+        # demo_action,
+        # demo_next_obses,
+        # demo_reward,
+        # demo_mask,
+        # demo_done,
+        # demo_idxs,
+        # demo_weights,
+        # ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
 
-            if isinstance(obs, dict):
-                obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
-                next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
-            else:
-                obs = torch.cat([obs, demo_obs])
-                next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
-            action = torch.cat([action, demo_action], dim=1)
-            reward = torch.cat([reward, demo_reward], dim=1)
-            mask = torch.cat([mask, demo_mask], dim=1)
-            done = torch.cat([done, demo_done], dim=1)
-            idxs = torch.cat([idxs, demo_idxs])
-            weights = torch.cat([weights, demo_weights])
+        # if isinstance(obs, dict):
+        # obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
+        # next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
+        # else:
+        # obs = torch.cat([obs, demo_obs])
+        # next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
+        # action = torch.cat([action, demo_action], dim=1)
+        # reward = torch.cat([reward, demo_reward], dim=1)
+        # mask = torch.cat([mask, demo_mask], dim=1)
+        # done = torch.cat([done, demo_done], dim=1)
+        # idxs = torch.cat([idxs, demo_idxs])
+        # weights = torch.cat([weights, demo_weights])
 
         # Apply augmentations
         aug_tf = h.aug(self.cfg)
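Aside: the commented-out `process_batch` above reorders flat trajectory slices into time-major tensors before indexing the first frame. A self-contained toy illustration of that `(t h) ... -> h t ...` reshape (shapes are arbitrary for the example):

```python
import torch

# `num_slices` trajectory windows of length `horizon`, flattened along the batch dim.
num_slices, horizon, dim = 4, 5, 3
flat = torch.arange(num_slices * horizon * dim, dtype=torch.float32).reshape(num_slices * horizon, dim)

# Same reshape + transpose as the commented-out code: batch-major -> time-major.
time_major = flat.reshape(num_slices, horizon, dim).transpose(1, 0).contiguous()
assert time_major.shape == (horizon, num_slices, dim)

# time_major[t, s] is step t of slice s; indexing the first step of every slice
# is analogous to indexing with FIRST_FRAME in the original code.
first_steps = time_major[0]  # shape (num_slices, dim)
```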
@@ -17,7 +17,7 @@ env:
   from_pixels: True
   pixels_only: False
   image_size: 84
-  action_repeat: 2
+  # action_repeat: 2  # we can remove if policy has n_action_steps=2
   episode_length: 25
   fps: ${fps}
 
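A quick sanity check on the inline comment above (`action_repeat` can go away once the policy uses `n_action_steps: 2`): both settings halve how often the planner is queried, although they are not strictly equivalent (action repeat replays the same action twice, while `n_action_steps` plays two consecutive queued actions). A hypothetical counting sketch, not part of the lerobot API:

```python
def planner_queries(total_env_steps: int, n_action_steps: int, action_repeat: int) -> int:
    """Number of planner invocations needed to cover `total_env_steps` env steps."""
    env_steps_per_query = n_action_steps * action_repeat
    return total_env_steps // env_steps_per_query


# Either knob queries the planner once per 2 environment steps:
assert planner_queries(50, n_action_steps=1, action_repeat=2) == 25
assert planner_queries(50, n_action_steps=2, action_repeat=1) == 25
```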
@@ -1,6 +1,6 @@
 # @package _global_
 
-n_action_steps: 1
+n_action_steps: 2
 n_obs_steps: 1
 
 policy:
@@ -6,7 +6,7 @@
 #SBATCH --time=2-00:00:00
 #SBATCH --output=/home/rcadene/slurm/%j.out
 #SBATCH --error=/home/rcadene/slurm/%j.err
-#SBATCH --qos=medium
+#SBATCH --qos=low
 #SBATCH --mail-user=re.cadene@gmail.com
 #SBATCH --mail-type=ALL
 
@@ -20,4 +20,6 @@ source ~/.bashrc
 #conda activate fowm
 conda activate lerobot
 
+export DATA_DIR="data"
+
 srun $CMD