revert dp changes, make act and tdmpc batch friendly

This commit is contained in:
Alexander Soare 2024-03-18 19:18:21 +00:00
parent 09ddd9bf92
commit 88347965c2
8 changed files with 32 additions and 58 deletions

View File

@ -1,15 +1,20 @@
from abc import abstractmethod from abc import ABC, abstractmethod
from collections import deque from collections import deque
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
class AbstractPolicy(nn.Module): class AbstractPolicy(nn.Module, ABC):
"""Base policy which all policies should be derived from.
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
documentation for more information.
"""
@abstractmethod @abstractmethod
def update(self, replay_buffer, step): def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm.""" """One step of the policy's learning algorithm."""
pass
def save(self, fp): def save(self, fp):
torch.save(self.state_dict(), fp) torch.save(self.state_dict(), fp)
@ -24,7 +29,6 @@ class AbstractPolicy(nn.Module):
Should return a (batch_size, n_action_steps, *) tensor of actions. Should return a (batch_size, n_action_steps, *) tensor of actions.
""" """
pass
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
"""Inference step that makes multi-step policies compatible with their single-step environments. """Inference step that makes multi-step policies compatible with their single-step environments.

View File

@ -153,10 +153,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
self.eval() self.eval()
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image", "top"] = observation["image", "top"].unsqueeze(0)
# observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove hack # TODO(rcadene): remove hack
# add 1 camera dimension # add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1) observation["image", "top"] = observation["image", "top"].unsqueeze(1)
@ -180,11 +176,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# remove bsize=1
action = action.squeeze(0)
# take first predicted action or n first actions # take first predicted action or n first actions
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps] action = action[: self.n_action_steps]
return action return action
def _forward(self, qpos, image, actions=None, is_pad=None): def _forward(self, qpos, image, actions=None, is_pad=None):

View File

@ -1,37 +1,15 @@
import copy import copy
from typing import Dict, Optional, Tuple, Union from typing import Dict, Tuple, Union
import timm
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
"""
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
def forward(self, x):
return torch.flatten(self.pool(self.backbone(x)), start_dim=1)
class MultiImageObsEncoder(ModuleAttrMixin): class MultiImageObsEncoder(ModuleAttrMixin):
def __init__( def __init__(
self, self,
@ -46,7 +24,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
share_rgb_model: bool = False, share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization # renormalize rgb input with imagenet normalization
# assuming input in [0,1] # assuming input in [0,1]
norm_mean_std: Optional[tuple[float, float]] = None, imagenet_norm: bool = False,
): ):
""" """
Assumes rgb input: B,C,H,W Assumes rgb input: B,C,H,W
@ -120,9 +98,10 @@ class MultiImageObsEncoder(ModuleAttrMixin):
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer # configure normalizer
this_normalizer = nn.Identity() this_normalizer = nn.Identity()
if norm_mean_std is not None: if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
this_normalizer = torchvision.transforms.Normalize( this_normalizer = torchvision.transforms.Normalize(
mean=norm_mean_std[0], std=norm_mean_std[1] mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
) )
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)

View File

@ -7,7 +7,7 @@ import torch
from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
class DiffusionPolicy(AbstractPolicy): class DiffusionPolicy(AbstractPolicy):
@ -38,7 +38,7 @@ class DiffusionPolicy(AbstractPolicy):
self.cfg = cfg self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model) rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder( obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model, rgb_model=rgb_model,
**cfg_obs_encoder, **cfg_obs_encoder,

View File

@ -128,11 +128,6 @@ class TDMPC(AbstractPolicy):
def select_action(self, observation, step_count): def select_action(self, observation, step_count):
t0 = step_count.item() == 0 t0 = step_count.item() == 0
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs = { obs = {
# TODO(rcadene): remove contiguous hack... # TODO(rcadene): remove contiguous hack...
"rgb": observation["image"].contiguous(), "rgb": observation["image"].contiguous(),
@ -149,7 +144,7 @@ class TDMPC(AbstractPolicy):
if self.cfg.mpc: if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step) a = self.plan(z, t0=t0, step=step)
else: else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0) a = self.model.pi(z, self.cfg.min_std * self.model.training)
return a return a
@torch.no_grad() @torch.no_grad()

View File

@ -42,8 +42,8 @@ policy:
num_inference_steps: 100 num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond} obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null # crop_shape: null
diffusion_step_embed_dim: 128 diffusion_step_embed_dim: 256 # before 128
down_dims: [512, 1024, 2048] down_dims: [256, 512, 1024] # before [512, 1024, 2048]
kernel_size: 5 kernel_size: 5
n_groups: 8 n_groups: 8
cond_predict_scale: True cond_predict_scale: True
@ -81,12 +81,12 @@ obs_encoder:
# random_crop: True # random_crop: True
use_group_norm: True use_group_norm: True
share_rgb_model: False share_rgb_model: False
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) imagenet_norm: True
rgb_model: rgb_model:
model_name: resnet18 _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
pretrained: false name: resnet18
num_keypoints: 32 weights: null
ema: ema:
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
@ -109,13 +109,13 @@ training:
debug: False debug: False
resume: True resume: True
# optimization # optimization
lr_scheduler: cosine # lr_scheduler: cosine
lr_warmup_steps: 500 # lr_warmup_steps: 500
num_epochs: 500 num_epochs: 8000
# gradient_accumulate_every: 1 # gradient_accumulate_every: 1
# EMA destroys performance when used with BatchNorm # EMA destroys performance when used with BatchNorm
# replace BatchNorm with GroupNorm. # replace BatchNorm with GroupNorm.
use_ema: True # use_ema: True
freeze_encoder: False freeze_encoder: False
# training loop control # training loop control
# in epochs # in epochs

View File

@ -135,8 +135,8 @@ def eval(cfg: dict, out_dir=None):
cfg.rollout_batch_size, cfg.rollout_batch_size,
create_env_fn=make_env, create_env_fn=make_env,
create_env_kwargs=[ create_env_kwargs=[
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform} {"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform}
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
], ],
) )

View File

@ -84,6 +84,9 @@ def test_abstract_policy_forward():
self.n_action_steps = n_action_steps self.n_action_steps = n_action_steps
self.n_policy_invocations = 0 self.n_policy_invocations = 0
def update(self):
pass
def select_action(self): def select_action(self):
self.n_policy_invocations += 1 self.n_policy_invocations += 1
return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0) return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)