backup wip

This commit is contained in:
Alexander Soare 2024-03-20 09:17:02 +00:00
parent 5332766a82
commit 32e3f71dd1
5 changed files with 75 additions and 26 deletions

View File

@ -190,11 +190,10 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
# run sampling
nsample = self.conditional_sample(
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
)
action_pred = nsample[..., :action_dim]
# get action
start = n_obs_steps - 1
end = start + self.n_action_steps

View File

@ -1,15 +1,40 @@
import copy
from typing import Dict, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import timm
import torch
import torch.nn as nn
import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
"""
input_shape: channel-first input shape (C, H, W)
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
# self.backbone = ResNet18Conv(input_channel=input_shape[0])
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
def forward(self, x):
return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
@ -24,7 +49,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
imagenet_norm: bool = False,
norm_mean_std: Optional[tuple[float, float]] = None,
):
"""
Assumes rgb input: B,C,H,W
@ -98,10 +123,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
if norm_mean_std is not None:
this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
mean=norm_mean_std[0], std=norm_mean_std[1]
)
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
@ -124,6 +148,17 @@ class MultiImageObsEncoder(ModuleAttrMixin):
def forward(self, obs_dict):
batch_size = None
features = []
# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)
# process rgb input
if self.share_rgb_model:
# pass all rgb obs to rgb model
@ -147,6 +182,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
# feature = torch.nn.functional.relu(feature) # TODO: make optional
features.append(feature)
else:
# run each rgb obs to independent models
@ -159,18 +195,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
# feature = torch.nn.functional.relu(feature) # TODO: make optional
features.append(feature)
# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)
# concatenate all features
result = torch.cat(features, dim=-1)
return result

View File

@ -7,7 +7,7 @@ import torch
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
class DiffusionPolicy(AbstractPolicy):
@ -38,6 +38,10 @@ class DiffusionPolicy(AbstractPolicy):
self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
if cfg_obs_encoder.crop_shape is not None:
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,

View File

@ -40,4 +40,23 @@ def make_policy(cfg):
raise NotImplementedError()
policy.load(cfg.policy.pretrained_model_path)
# import torch
# loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
# aligned = {}
# their_prefix = "obs_encoder.obs_nets.image.backbone"
# our_prefix = "obs_encoder.key_model_map.image.backbone"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# their_prefix = "obs_encoder.obs_nets.image.pool"
# our_prefix = "obs_encoder.key_model_map.image.pool"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# their_prefix = "obs_encoder.obs_nets.image.nets.3"
# our_prefix = "obs_encoder.key_model_map.image.out"
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
# aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
# missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False)
# assert all('_dummy_variable' in k for k in missing_keys)
# assert len(unexpected_keys) == 0
return policy

View File

@ -42,8 +42,8 @@ policy:
num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null
diffusion_step_embed_dim: 256 # before 128
down_dims: [256, 512, 1024] # before [512, 1024, 2048]
diffusion_step_embed_dim: 128
down_dims: [512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
@ -76,17 +76,17 @@ noise_scheduler:
obs_encoder:
shape_meta: ${shape_meta}
# resize_shape: null
# crop_shape: [76, 76]
crop_shape: [84, 84]
# constant center crop
# random_crop: True
random_crop: True
use_group_norm: True
share_rgb_model: False
imagenet_norm: True
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
rgb_model:
_target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
name: resnet18
weights: null
model_name: resnet18
pretrained: false
num_keypoints: 32
ema:
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel