revert dp changes, make act and tdmpc batch friendly
commit 88347965c2
parent 09ddd9bf92
@@ -1,15 +1,20 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from collections import deque
 
 import torch
 from torch import Tensor, nn
 
 
-class AbstractPolicy(nn.Module):
+class AbstractPolicy(nn.Module, ABC):
+    """Base policy which all policies should be derived from.
+
+    The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
+    documentation for more information.
+    """
+
     @abstractmethod
     def update(self, replay_buffer, step):
         """One step of the policy's learning algorithm."""
-        pass
 
     def save(self, fp):
         torch.save(self.state_dict(), fp)

@@ -24,7 +29,6 @@ class AbstractPolicy(nn.Module):
 
         Should return a (batch_size, n_action_steps, *) tensor of actions.
         """
-        pass
 
     def forward(self, *args, **kwargs):
         """Inference step that makes multi-step policies compatible with their single-step environments.
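Making the base class an ABC means a policy that does not implement the abstract methods can no longer be instantiated. A hypothetical subclass, sketched only from what this hunk shows (the constructor arguments and the full set of abstract methods are not visible here):

import torch

from lerobot.common.policies.abstract import AbstractPolicy


class ZeroPolicy(AbstractPolicy):
    # Illustrative stub, not part of the commit.
    def update(self, replay_buffer, step):
        return {}  # no learning step in this sketch

    def select_action(self, observation, step_count):
        # Per the docstring: a (batch_size, n_action_steps, *) tensor of actions.
        return torch.zeros(1, 1, 2)  # hypothetical batch of 1, one step, 2-D action
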
@@ -153,10 +153,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
 
         self.eval()
 
-        # TODO(rcadene): remove unsqueeze hack to add bsize=1
-        observation["image", "top"] = observation["image", "top"].unsqueeze(0)
-        # observation["state"] = observation["state"].unsqueeze(0)
-
         # TODO(rcadene): remove hack
         # add 1 camera dimension
         observation["image", "top"] = observation["image", "top"].unsqueeze(1)

@@ -180,11 +176,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
         # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
         # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
 
-        # remove bsize=1
-        action = action.squeeze(0)
-
         # take first predicted action or n first actions
-        action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
+        action = action[: self.n_action_steps]
         return action
 
     def _forward(self, qpos, image, actions=None, is_pad=None):
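The n_action_steps == 1 special case is gone: the returned chunk is always sliced with the same expression, so downstream code sees a consistent number of dimensions. A small illustration (the shapes are made up, only the slicing behaviour matters):

import torch

n_action_steps = 1
action = torch.randn(100, 14)  # an illustrative predicted chunk

# Previously, n_action_steps == 1 returned action[0] with shape (14,).
# Now the step dimension is always kept:
action = action[:n_action_steps]
print(action.shape)  # torch.Size([1, 14])
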
@@ -1,37 +1,15 @@
 import copy
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Tuple, Union
 
-import timm
 import torch
 import torch.nn as nn
 import torchvision
-from robomimic.models.base_nets import SpatialSoftmax
 
 from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
 from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
 from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
 
 
-class RgbEncoder(nn.Module):
-    """Following `VisualCore` from Robomimic 0.2.0."""
-
-    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
-        """
-        resnet_name: a timm model name.
-        pretrained: whether to use timm pretrained weights.
-        num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
-        """
-        super().__init__()
-        self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
-        # Figure out the feature map shape.
-        with torch.inference_mode():
-            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
-
-    def forward(self, x):
-        return torch.flatten(self.pool(self.backbone(x)), start_dim=1)
-
-
 class MultiImageObsEncoder(ModuleAttrMixin):
     def __init__(
         self,

@@ -46,7 +24,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
         share_rgb_model: bool = False,
         # renormalize rgb input with imagenet normalization
         # assuming input in [0,1]
-        norm_mean_std: Optional[tuple[float, float]] = None,
+        imagenet_norm: bool = False,
     ):
         """
         Assumes rgb input: B,C,H,W

@@ -120,9 +98,10 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                     this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
                 # configure normalizer
                 this_normalizer = nn.Identity()
-                if norm_mean_std is not None:
+                if imagenet_norm:
+                    # TODO(rcadene): move normalizer to dataset and env
                     this_normalizer = torchvision.transforms.Normalize(
-                        mean=norm_mean_std[0], std=norm_mean_std[1]
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                     )
 
                 this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
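The hard-coded mean and std above are the standard ImageNet statistics, so imagenet_norm: True simply applies the usual torchvision normalization to inputs already scaled to [0, 1]. A standalone sketch of that transform (the image size is illustrative):

import torch
import torchvision

normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
x = torch.rand(3, 96, 96)   # RGB image in [0, 1]
y = normalize(x)            # roughly zero-centered per channel
print(y.mean(dim=(1, 2)))
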
@@ -7,7 +7,7 @@ import torch
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
-from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
+from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
 
 
 class DiffusionPolicy(AbstractPolicy):

@@ -38,7 +38,7 @@ class DiffusionPolicy(AbstractPolicy):
         self.cfg = cfg
 
         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
-        rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model)
+        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
             **cfg_obs_encoder,
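With hydra.utils.instantiate, the encoder is built from whatever _target_ the rgb_model config names instead of a hard-wired RgbEncoder. A minimal sketch of the mechanism with an inline config; the target here is torchvision's resnet18, used purely for illustration (the actual config points at lerobot's get_resnet helper):

import hydra
from omegaconf import OmegaConf

cfg_rgb_model = OmegaConf.create({"_target_": "torchvision.models.resnet18", "weights": None})
rgb_model = hydra.utils.instantiate(cfg_rgb_model)  # builds the object named by _target_
print(type(rgb_model).__name__)  # ResNet
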
@@ -128,11 +128,6 @@ class TDMPC(AbstractPolicy):
     def select_action(self, observation, step_count):
         t0 = step_count.item() == 0
 
-        # TODO(rcadene): remove unsqueeze hack...
-        if observation["image"].ndim == 3:
-            observation["image"] = observation["image"].unsqueeze(0)
-            observation["state"] = observation["state"].unsqueeze(0)
-
         obs = {
             # TODO(rcadene): remove contiguous hack...
             "rgb": observation["image"].contiguous(),

@@ -149,7 +144,7 @@ class TDMPC(AbstractPolicy):
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
         else:
-            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
+            a = self.model.pi(z, self.cfg.min_std * self.model.training)
         return a
 
     @torch.no_grad()
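With the ndim == 3 unsqueeze hack removed, select_action assumes the observation already carries a batch dimension. A sketch of the assumed layout (channel count, image size and state size are illustrative):

import torch

observation = {
    "image": torch.rand(4, 3, 84, 84),  # (batch_size, C, H, W)
    "state": torch.rand(4, 4),          # (batch_size, state_dim)
}
step_count = torch.tensor([0])
# policy.select_action(observation, step_count) is now expected to return one
# action per batch element rather than a single squeezed action.
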
@@ -42,8 +42,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 128
-  down_dims: [512, 1024, 2048]
+  diffusion_step_embed_dim: 256 # before 128
+  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True

@@ -81,12 +81,12 @@ obs_encoder:
   # random_crop: True
   use_group_norm: True
   share_rgb_model: False
-  norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
+  imagenet_norm: True
 
 rgb_model:
-  model_name: resnet18
-  pretrained: false
-  num_keypoints: 32
+  _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
+  name: resnet18
+  weights: null
 
 ema:
   _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
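The rgb_model group is now instantiated through Hydra: _target_ names a get_resnet helper and name/weights are forwarded to it. A hedged sketch of what such a helper typically looks like in the upstream diffusion-policy code (not necessarily the exact lerobot implementation):

import torch.nn as nn
import torchvision


def get_resnet(name: str, weights=None, **kwargs) -> nn.Module:
    """Return a torchvision ResNet with the classification head replaced by Identity."""
    resnet = getattr(torchvision.models, name)(weights=weights, **kwargs)
    resnet.fc = nn.Identity()  # expose the pooled feature vector instead of logits
    return resnet
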
@@ -109,13 +109,13 @@ training:
   debug: False
   resume: True
   # optimization
-  lr_scheduler: cosine
-  lr_warmup_steps: 500
-  num_epochs: 500
+  # lr_scheduler: cosine
+  # lr_warmup_steps: 500
+  num_epochs: 8000
   # gradient_accumulate_every: 1
   # EMA destroys performance when used with BatchNorm
   # replace BatchNorm with GroupNorm.
-  use_ema: True
+  # use_ema: True
   freeze_encoder: False
   # training loop control
   # in epochs
@@ -135,8 +135,8 @@ def eval(cfg: dict, out_dir=None):
         cfg.rollout_batch_size,
         create_env_fn=make_env,
         create_env_kwargs=[
-            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
-            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+            {"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform}
+            for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
         ],
     )
 
@@ -84,6 +84,9 @@ def test_abstract_policy_forward():
             self.n_action_steps = n_action_steps
             self.n_policy_invocations = 0
 
+        def update(self):
+            pass
+
         def select_action(self):
             self.n_policy_invocations += 1
             return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
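The test stub gains an update override because AbstractPolicy now enforces its abstract methods through ABC; without it, instantiating the stub inside the test would raise a TypeError. A tiny standalone illustration of that behaviour (class names are made up):

from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def update(self):
        ...


class Incomplete(Base):
    pass


class Complete(Base):
    def update(self):
        pass


Complete()  # fine
try:
    Incomplete()
except TypeError as err:
    print(err)  # Can't instantiate abstract class Incomplete ...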