backup wip

Alexander Soare 2024-03-20 09:17:02 +00:00
parent 5332766a82
commit 32e3f71dd1
5 changed files with 75 additions and 26 deletions

View File

@@ -190,11 +190,10 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
         # run sampling
         nsample = self.conditional_sample(
-            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
+            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
         )
         action_pred = nsample[..., :action_dim]
         # get action
         start = n_obs_steps - 1
         end = start + self.n_action_steps

View File

@@ -1,15 +1,40 @@
 import copy
-from typing import Dict, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

+import timm
 import torch
 import torch.nn as nn
 import torchvision
+from robomimic.models.base_nets import SpatialSoftmax

 from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
 from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
 from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
+
+
+class RgbEncoder(nn.Module):
+    """Following `VisualCore` from Robomimic 0.2.0."""
+
+    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
+        """
+        input_shape: channel-first input shape (C, H, W)
+        model_name: a timm model name.
+        pretrained: whether to use timm pretrained weights.
+        num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
+        """
+        super().__init__()
+        self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
+        # self.backbone = ResNet18Conv(input_channel=input_shape[0])
+        # Figure out the feature map shape.
+        with torch.inference_mode():
+            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
+        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
+        self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
+
+    def forward(self, x):
+        return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))


 class MultiImageObsEncoder(ModuleAttrMixin):
     def __init__(
         self,
@@ -24,7 +49,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
         share_rgb_model: bool = False,
         # renormalize rgb input with imagenet normalization
         # assuming input in [0,1]
-        imagenet_norm: bool = False,
+        norm_mean_std: Optional[tuple[float, float]] = None,
     ):
         """
         Assumes rgb input: B,C,H,W
@ -98,10 +123,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer # configure normalizer
this_normalizer = nn.Identity() this_normalizer = nn.Identity()
if imagenet_norm: if norm_mean_std is not None:
# TODO(rcadene): move normalizer to dataset and env
this_normalizer = torchvision.transforms.Normalize( this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] mean=norm_mean_std[0], std=norm_mean_std[1]
) )
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
@@ -124,6 +148,17 @@ class MultiImageObsEncoder(ModuleAttrMixin):
     def forward(self, obs_dict):
         batch_size = None
         features = []
+
+        # process lowdim input
+        for key in self.low_dim_keys:
+            data = obs_dict[key]
+            if batch_size is None:
+                batch_size = data.shape[0]
+            else:
+                assert batch_size == data.shape[0]
+            assert data.shape[1:] == self.key_shape_map[key]
+            features.append(data)
+
         # process rgb input
         if self.share_rgb_model:
             # pass all rgb obs to rgb model
@@ -147,6 +182,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
             feature = torch.moveaxis(feature, 0, 1)
             # (B,N*D)
             feature = feature.reshape(batch_size, -1)
+            # feature = torch.nn.functional.relu(feature)  # TODO: make optional
             features.append(feature)
         else:
             # run each rgb obs to independent models
@@ -159,18 +195,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                 assert img.shape[1:] == self.key_shape_map[key]
                 img = self.key_transform_map[key](img)
                 feature = self.key_model_map[key](img)
+                # feature = torch.nn.functional.relu(feature)  # TODO: make optional
                 features.append(feature)
-
-        # process lowdim input
-        for key in self.low_dim_keys:
-            data = obs_dict[key]
-            if batch_size is None:
-                batch_size = data.shape[0]
-            else:
-                assert batch_size == data.shape[0]
-            assert data.shape[1:] == self.key_shape_map[key]
-            features.append(data)

         # concatenate all features
         result = torch.cat(features, dim=-1)
         return result
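
Note: a minimal shape check for the new RgbEncoder, assuming the import path introduced in this file and an 84x84 input matching the crop_shape set in the config below. SpatialSoftmax emits an (x, y) pair per keypoint, so the flat feature should be num_keypoints * 2 = 64 wide.

    import torch

    from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import RgbEncoder

    # Defaults taken from the diff; pretrained=False avoids downloading timm weights.
    encoder = RgbEncoder(input_shape=(3, 84, 84), model_name="resnet18", pretrained=False, num_keypoints=32)
    with torch.inference_mode():
        feats = encoder(torch.zeros(2, 3, 84, 84))  # dummy batch of 2 images in [0, 1]
    print(feats.shape)  # expected: torch.Size([2, 64]) == (batch, num_keypoints * 2)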

View File

@ -7,7 +7,7 @@ import torch
from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
class DiffusionPolicy(AbstractPolicy): class DiffusionPolicy(AbstractPolicy):
@ -38,6 +38,10 @@ class DiffusionPolicy(AbstractPolicy):
self.cfg = cfg self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
if cfg_obs_encoder.crop_shape is not None:
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
rgb_model = hydra.utils.instantiate(cfg_rgb_model) rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder( obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model, rgb_model=rgb_model,
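
Note: the added shape bookkeeping keeps the channel dimension and swaps H and W for the crop size before building the encoder. A toy illustration, with made-up PushT-like values standing in for shape_meta.obs.image.shape and cfg_obs_encoder.crop_shape:

    import copy

    shape_meta_image_shape = [3, 96, 96]  # hypothetical shape_meta.obs.image.shape
    crop_shape = [84, 84]                 # hypothetical cfg_obs_encoder.crop_shape

    rgb_model_input_shape = copy.deepcopy(shape_meta_image_shape)
    if crop_shape is not None:
        rgb_model_input_shape[1:] = crop_shape  # keep C, replace (H, W)
    print(rgb_model_input_shape)  # [3, 84, 84]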

View File

@@ -40,4 +40,23 @@ def make_policy(cfg):
         raise NotImplementedError()

     policy.load(cfg.policy.pretrained_model_path)

+    # import torch
+    # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
+    # aligned = {}
+    # their_prefix = "obs_encoder.obs_nets.image.backbone"
+    # our_prefix = "obs_encoder.key_model_map.image.backbone"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+    # their_prefix = "obs_encoder.obs_nets.image.pool"
+    # our_prefix = "obs_encoder.key_model_map.image.pool"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+    # their_prefix = "obs_encoder.obs_nets.image.nets.3"
+    # our_prefix = "obs_encoder.key_model_map.image.out"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+    # aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
+    # missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False)
+    # assert all('_dummy_variable' in k for k in missing_keys)
+    # assert len(unexpected_keys) == 0
+
     return policy
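
Note: the commented-out block above remaps state-dict keys from the upstream diffusion_policy checkpoint layout (obs_nets.*) to this repo's module names (key_model_map.*). A generic sketch of that remapping; the prefix pairs are taken from the comments, everything else (checkpoint path, strictness) is an assumption:

    PREFIX_MAP = {
        "obs_encoder.obs_nets.image.backbone": "obs_encoder.key_model_map.image.backbone",
        "obs_encoder.obs_nets.image.pool": "obs_encoder.key_model_map.image.pool",
        "obs_encoder.obs_nets.image.nets.3": "obs_encoder.key_model_map.image.out",
    }

    def remap_state_dict(loaded: dict) -> dict:
        """Rename keys with known prefixes; pass 'model.' keys through unchanged."""
        aligned = {}
        for k, v in loaded.items():
            for theirs, ours in PREFIX_MAP.items():
                if k.startswith(theirs):
                    aligned[ours + k.removeprefix(theirs)] = v
                    break
            else:
                if k.startswith("model."):
                    aligned[k] = v
        return aligned

    # Usage would mirror the comments:
    # policy.diffusion.load_state_dict(remap_state_dict(loaded), strict=False)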

View File

@@ -42,8 +42,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 256 # before 128
-  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
+  diffusion_step_embed_dim: 128
+  down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True
@ -76,17 +76,17 @@ noise_scheduler:
obs_encoder: obs_encoder:
shape_meta: ${shape_meta} shape_meta: ${shape_meta}
# resize_shape: null # resize_shape: null
# crop_shape: [76, 76] crop_shape: [84, 84]
# constant center crop # constant center crop
# random_crop: True random_crop: True
use_group_norm: True use_group_norm: True
share_rgb_model: False share_rgb_model: False
imagenet_norm: True norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
rgb_model: rgb_model:
_target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet model_name: resnet18
name: resnet18 pretrained: false
weights: null num_keypoints: 32
ema: ema:
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
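
Note on the norm_mean_std setting above: with scalar mean and std of 0.5, torchvision's Normalize maps inputs from [0, 1] to [-1, 1], i.e. (x - 0.5) / 0.5. A quick sanity check, assuming a dummy 84x84 image tensor in [0, 1]:

    import torch
    import torchvision

    norm = torchvision.transforms.Normalize(mean=0.5, std=0.5)
    img = torch.rand(3, 84, 84)  # dummy image in [0, 1)
    out = norm(img)
    print(out.min().item() >= -1.0, out.max().item() <= 1.0)  # True True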