diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
index b759802e..c5b00d94 100644
--- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
+++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
@@ -190,11 +190,10 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
 
         # run sampling
         nsample = self.conditional_sample(
-            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
+            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
         )
 
         action_pred = nsample[..., :action_dim]
 
-        # get action
         start = n_obs_steps - 1
         end = start + self.n_action_steps
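For reference, the `start`/`end` indexing kept above selects the executed action window from the sampled horizon, aligned with the latest observation step. A minimal sketch with illustrative values (not taken from the config):

```python
# Sketch only: how the action window is cut from the sampled trajectory.
# n_obs_steps and n_action_steps are hypothetical values here;
# action_pred comes from nsample[..., :action_dim] above.
n_obs_steps = 2
n_action_steps = 8

start = n_obs_steps - 1       # 1: the step aligned with the latest observation
end = start + n_action_steps  # 9
# action = action_pred[:, start:end]  # (B, n_action_steps, action_dim)
```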
diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
index 94dc6f49..17252c1c 100644
--- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
+++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
@@ -1,15 +1,40 @@
 import copy
-from typing import Dict, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
+import timm
 import torch
 import torch.nn as nn
 import torchvision
+from robomimic.models.base_nets import SpatialSoftmax
 
 from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
 from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
 from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
 
 
+class RgbEncoder(nn.Module):
+    """Following `VisualCore` from Robomimic 0.2.0."""
+
+    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
+        """
+        input_shape: channel-first input shape (C, H, W)
+        model_name: a timm model name.
+        pretrained: whether to use timm pretrained weights.
+        num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
+        """
+        super().__init__()
+        self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
+        # self.backbone = ResNet18Conv(input_channel=input_shape[0])
+        # Figure out the feature map shape.
+        with torch.inference_mode():
+            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
+        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
+        self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
+
+    def forward(self, x):
+        return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
+
+
 class MultiImageObsEncoder(ModuleAttrMixin):
     def __init__(
         self,
@@ -24,7 +49,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
         share_rgb_model: bool = False,
         # renormalize rgb input with imagenet normalization
         # assuming input in [0,1]
-        imagenet_norm: bool = False,
+        norm_mean_std: Optional[tuple[float, float]] = None,
     ):
         """
         Assumes rgb input: B,C,H,W
@@ -98,10 +123,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                     this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
                 # configure normalizer
                 this_normalizer = nn.Identity()
-                if imagenet_norm:
-                    # TODO(rcadene): move normalizer to dataset and env
+                if norm_mean_std is not None:
                     this_normalizer = torchvision.transforms.Normalize(
-                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                        mean=norm_mean_std[0], std=norm_mean_std[1]
                     )
 
                 this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
@@ -124,6 +148,17 @@ class MultiImageObsEncoder(ModuleAttrMixin):
     def forward(self, obs_dict):
         batch_size = None
         features = []
+
+        # process lowdim input
+        for key in self.low_dim_keys:
+            data = obs_dict[key]
+            if batch_size is None:
+                batch_size = data.shape[0]
+            else:
+                assert batch_size == data.shape[0]
+            assert data.shape[1:] == self.key_shape_map[key]
+            features.append(data)
+
         # process rgb input
         if self.share_rgb_model:
             # pass all rgb obs to rgb model
@@ -147,6 +182,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                 feature = torch.moveaxis(feature, 0, 1)
                 # (B,N*D)
                 feature = feature.reshape(batch_size, -1)
+                # feature = torch.nn.functional.relu(feature)  # TODO: make optional
                 features.append(feature)
         else:
             # run each rgb obs to independent models
@@ -159,18 +195,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                 assert img.shape[1:] == self.key_shape_map[key]
                 img = self.key_transform_map[key](img)
                 feature = self.key_model_map[key](img)
+                # feature = torch.nn.functional.relu(feature)  # TODO: make optional
                 features.append(feature)
 
-        # process lowdim input
-        for key in self.low_dim_keys:
-            data = obs_dict[key]
-            if batch_size is None:
-                batch_size = data.shape[0]
-            else:
-                assert batch_size == data.shape[0]
-            assert data.shape[1:] == self.key_shape_map[key]
-            features.append(data)
-
         # concatenate all features
         result = torch.cat(features, dim=-1)
         return result
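A quick shape check for the new `RgbEncoder` (a sketch, assuming `timm` and `robomimic` are installed). With a ResNet18 backbone created with `num_classes=0, global_pool=""`, SpatialSoftmax returns an (x, y) coordinate per keypoint, so the flattened feature dimension is `num_keypoints * 2`:

```python
import torch

from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import RgbEncoder

# 84x84 matches the crop_shape configured in diffusion.yaml below.
encoder = RgbEncoder(input_shape=(3, 84, 84), model_name="resnet18", pretrained=False, num_keypoints=32)
with torch.inference_mode():
    feats = encoder(torch.zeros(2, 3, 84, 84))
assert feats.shape == (2, 32 * 2)  # 32 keypoints, each an (x, y) pair
```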
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index 2c47f172..f68ffb8e 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -7,7 +7,7 @@ import torch
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
-from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
+from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
 
 
 class DiffusionPolicy(AbstractPolicy):
@@ -38,6 +38,9 @@ class DiffusionPolicy(AbstractPolicy):
         self.cfg = cfg
 
         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
-        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
+        rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
+        if cfg_obs_encoder.crop_shape is not None:
+            rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
+        rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 085baab5..7961beed 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -40,4 +40,23 @@ def make_policy(cfg):
             raise NotImplementedError()
         policy.load(cfg.policy.pretrained_model_path)
 
+    # import torch
+    # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
+    # aligned = {}
+
+    # their_prefix = "obs_encoder.obs_nets.image.backbone"
+    # our_prefix = "obs_encoder.key_model_map.image.backbone"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+    # their_prefix = "obs_encoder.obs_nets.image.pool"
+    # our_prefix = "obs_encoder.key_model_map.image.pool"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+    # their_prefix = "obs_encoder.obs_nets.image.nets.3"
+    # our_prefix = "obs_encoder.key_model_map.image.out"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+
+    # aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
+    # missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False)
+    # assert all('_dummy_variable' in k for k in missing_keys)
+    # assert len(unexpected_keys) == 0
+
     return policy
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 0dae5056..2b63f7e1 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -42,8 +42,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 256 # before 128
-  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
+  diffusion_step_embed_dim: 128
+  down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True
@@ -76,17 +76,17 @@ noise_scheduler:
 obs_encoder:
   shape_meta: ${shape_meta}
   # resize_shape: null
-  # crop_shape: [76, 76]
+  crop_shape: [84, 84]
   # constant center crop
-  # random_crop: True
+  random_crop: True
   use_group_norm: True
   share_rgb_model: False
-  imagenet_norm: True
+  norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
 
 rgb_model:
-  _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
-  name: resnet18
-  weights: null
+  model_name: resnet18
+  pretrained: false
+  num_keypoints: 32
 
 ema:
   _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
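On `norm_mean_std: [0.5, 0.5]`: the scalar mean/std are passed straight to `torchvision.transforms.Normalize`, which broadcasts them over channels, so inputs in [0, 1] are mapped to [-1, 1], matching the original PushT normalization mentioned in the comment. A minimal check of that equivalence:

```python
import torch
import torchvision

# mean=0.5, std=0.5 computes (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1].
normalize = torchvision.transforms.Normalize(mean=0.5, std=0.5)
img = torch.rand(3, 84, 84)  # rgb input assumed in [0, 1]
out = normalize(img)
assert out.min() >= -1.0 and out.max() <= 1.0
```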