commit 32e3f71dd1 (parent 5332766a82)

    backup wip
@@ -190,11 +190,10 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):

         # run sampling
         nsample = self.conditional_sample(
-            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
+            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
         )

         action_pred = nsample[..., :action_dim]

         # get action
         start = n_obs_steps - 1
         end = start + self.n_action_steps
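Note: the only change in this hunk is dropping **self.kwargs from the conditional_sample call; the action-window slicing is unchanged context. A minimal sketch of what that slicing does, with sizes that are illustrative only (not taken from this commit):

import torch

# Illustrative sizes: horizon 16, 2 obs steps, 8 action steps, 2-dim actions.
horizon, n_obs_steps, n_action_steps, action_dim = 16, 2, 8, 2
nsample = torch.zeros(1, horizon, action_dim)  # stand-in for the sampler output

action_pred = nsample[..., :action_dim]
start = n_obs_steps - 1  # first action step lines up with the latest observation
end = start + n_action_steps
action = action_pred[:, start:end]
print(action.shape)  # torch.Size([1, 8, 2])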
@@ -1,15 +1,40 @@
 import copy
-from typing import Dict, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

+import timm
 import torch
 import torch.nn as nn
 import torchvision
+from robomimic.models.base_nets import SpatialSoftmax

 from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
 from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
 from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules


+class RgbEncoder(nn.Module):
+    """Following `VisualCore` from Robomimic 0.2.0."""
+
+    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
+        """
+        input_shape: channel-first input shape (C, H, W)
+        model_name: a timm model name.
+        pretrained: whether to use timm pretrained weights.
+        num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
+        """
+        super().__init__()
+        self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
+        # self.backbone = ResNet18Conv(input_channel=input_shape[0])
+        # Figure out the feature map shape.
+        with torch.inference_mode():
+            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
+        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
+        self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
+
+    def forward(self, x):
+        return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
+
+
 class MultiImageObsEncoder(ModuleAttrMixin):
     def __init__(
         self,
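A quick shape check for the new RgbEncoder (my sketch, assuming timm and robomimic are installed; the 84x84 input matches the crop_shape set in the config hunk further down). The backbone emits a feature map, SpatialSoftmax reduces it to num_keypoints (x, y) pairs, and the final Linear maps the flattened 2*num_keypoints vector to itself:

import timm
import torch
from robomimic.models.base_nets import SpatialSoftmax

backbone = timm.create_model("resnet18", pretrained=False, num_classes=0, global_pool="")
with torch.inference_mode():
    # Probe the backbone once to get the feature map shape, e.g. (512, 3, 3) for 84x84 input.
    feat_map_shape = tuple(backbone(torch.zeros(1, 3, 84, 84)).shape[1:])
    pool = SpatialSoftmax(feat_map_shape, num_kp=32)
    keypoints = pool(backbone(torch.zeros(8, 3, 84, 84)))   # (8, 32, 2) soft keypoints
print(torch.flatten(keypoints, start_dim=1).shape)          # torch.Size([8, 64]) -> Linear(64, 64)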
@@ -24,7 +49,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
         share_rgb_model: bool = False,
         # renormalize rgb input with imagenet normalization
         # assuming input in [0,1]
-        imagenet_norm: bool = False,
+        norm_mean_std: Optional[tuple[float, float]] = None,
     ):
         """
         Assumes rgb input: B,C,H,W
@@ -98,10 +123,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                     this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
                 # configure normalizer
                 this_normalizer = nn.Identity()
-                if imagenet_norm:
-                    # TODO(rcadene): move normalizer to dataset and env
+                if norm_mean_std is not None:
                     this_normalizer = torchvision.transforms.Normalize(
-                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                        mean=norm_mean_std[0], std=norm_mean_std[1]
                     )

                 this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
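The boolean imagenet_norm flag becomes an explicit (mean, std) pair. A small check, not part of the commit, that norm_mean_std = [0.5, 0.5] maps inputs from [0, 1] to [-1, 1], as the config comment below states:

import torch
import torchvision

norm_mean_std = (0.5, 0.5)  # value used for PushT in the config hunk below
normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
img = torch.rand(3, 84, 84)              # input assumed in [0, 1]
out = normalizer(img)                    # (x - 0.5) / 0.5
print(out.min() >= -1, out.max() <= 1)   # tensor(True) tensor(True)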
@@ -124,6 +148,17 @@ class MultiImageObsEncoder(ModuleAttrMixin):
     def forward(self, obs_dict):
         batch_size = None
         features = []
+
+        # process lowdim input
+        for key in self.low_dim_keys:
+            data = obs_dict[key]
+            if batch_size is None:
+                batch_size = data.shape[0]
+            else:
+                assert batch_size == data.shape[0]
+            assert data.shape[1:] == self.key_shape_map[key]
+            features.append(data)
+
         # process rgb input
         if self.share_rgb_model:
             # pass all rgb obs to rgb model
@@ -147,6 +182,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
             feature = torch.moveaxis(feature, 0, 1)
             # (B,N*D)
             feature = feature.reshape(batch_size, -1)
+            # feature = torch.nn.functional.relu(feature)  # TODO: make optional
             features.append(feature)
         else:
             # run each rgb obs to independent models
@@ -159,18 +195,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
                 assert img.shape[1:] == self.key_shape_map[key]
                 img = self.key_transform_map[key](img)
                 feature = self.key_model_map[key](img)
+                # feature = torch.nn.functional.relu(feature)  # TODO: make optional
                 features.append(feature)

-        # process lowdim input
-        for key in self.low_dim_keys:
-            data = obs_dict[key]
-            if batch_size is None:
-                batch_size = data.shape[0]
-            else:
-                assert batch_size == data.shape[0]
-            assert data.shape[1:] == self.key_shape_map[key]
-            features.append(data)
-
         # concatenate all features
         result = torch.cat(features, dim=-1)
         return result
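With this move, forward() gathers lowdim features first, then rgb features, and concatenates everything on the last dimension. A toy illustration with invented keys and sizes (one 2-dim lowdim input plus one 32-keypoint image feature):

import torch

batch_size = 8
features = []
features.append(torch.zeros(batch_size, 2))    # lowdim key, e.g. a 2-dim "agent_pos"
features.append(torch.zeros(batch_size, 64))   # rgb key: 32 keypoints x 2 coords
result = torch.cat(features, dim=-1)
print(result.shape)  # torch.Size([8, 66])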
@@ -7,7 +7,7 @@ import torch
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
-from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
+from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder


 class DiffusionPolicy(AbstractPolicy):
@@ -38,6 +38,10 @@ class DiffusionPolicy(AbstractPolicy):
         self.cfg = cfg

         noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
+        rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
+        if cfg_obs_encoder.crop_shape is not None:
+            rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
+        rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
         rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
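The encoder input shape is now derived from shape_meta, with the spatial dims replaced by the crop size when cropping is enabled. A sketch with assumed values (96x96 source images are an assumption, not stated in this diff; the 84x84 crop matches the config below):

import copy

image_shape = [3, 96, 96]   # stand-in for shape_meta.obs.image.shape
crop_shape = [84, 84]       # stand-in for cfg_obs_encoder.crop_shape

rgb_model_input_shape = copy.deepcopy(image_shape)
if crop_shape is not None:
    rgb_model_input_shape[1:] = crop_shape  # keep channels, swap in cropped H, W
print(rgb_model_input_shape)  # [3, 84, 84]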
@@ -40,4 +40,23 @@ def make_policy(cfg):
         raise NotImplementedError()
     policy.load(cfg.policy.pretrained_model_path)

+    # import torch
+    # loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
+    # aligned = {}
+
+    # their_prefix = "obs_encoder.obs_nets.image.backbone"
+    # our_prefix = "obs_encoder.key_model_map.image.backbone"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+    # their_prefix = "obs_encoder.obs_nets.image.pool"
+    # our_prefix = "obs_encoder.key_model_map.image.pool"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+    # their_prefix = "obs_encoder.obs_nets.image.nets.3"
+    # our_prefix = "obs_encoder.key_model_map.image.out"
+    # aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
+
+    # aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
+    # missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False)
+    # assert all('_dummy_variable' in k for k in missing_keys)
+    # assert len(unexpected_keys) == 0
+
     return policy
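The commented-out block above records how a checkpoint from the upstream layout (obs_encoder.obs_nets.*) could be re-keyed to this repo's layout (obs_encoder.key_model_map.*). The same prefix-rewriting trick on a dummy state dict, with invented keys for illustration:

loaded = {
    "obs_encoder.obs_nets.image.backbone.conv1.weight": "w0",  # invented key
    "model.some.layer.bias": "b0",                             # invented key
}
their_prefix = "obs_encoder.obs_nets.image.backbone"
our_prefix = "obs_encoder.key_model_map.image.backbone"
# Rewrite matching keys under our prefix; copy the rest of the diffusion model keys verbatim.
aligned = {our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}
aligned.update({k: v for k, v in loaded.items() if k.startswith("model.")})
print(sorted(aligned))
# ['model.some.layer.bias', 'obs_encoder.key_model_map.image.backbone.conv1.weight']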
@@ -42,8 +42,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 256 # before 128
-  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
+  diffusion_step_embed_dim: 128
+  down_dims: [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True
@@ -76,17 +76,17 @@ noise_scheduler:
 obs_encoder:
   shape_meta: ${shape_meta}
   # resize_shape: null
-  # crop_shape: [76, 76]
+  crop_shape: [84, 84]
   # constant center crop
-  # random_crop: True
+  random_crop: True
   use_group_norm: True
   share_rgb_model: False
-  imagenet_norm: True
+  norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)

 rgb_model:
-  _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
-  name: resnet18
-  weights: null
+  model_name: resnet18
+  pretrained: false
+  num_keypoints: 32

 ema:
   _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
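The rgb_model section no longer carries a hydra _target_; its keys now line up one-to-one with RgbEncoder's keyword arguments. Roughly, under the assumption that the section resolves to a plain dict:

# Assumed dict form of the rgb_model section above.
cfg_rgb_model = {"model_name": "resnet18", "pretrained": False, "num_keypoints": 32}
# With crop_shape: [84, 84], the policy hunk above would build the encoder as roughly:
# rgb_model = RgbEncoder(input_shape=[3, 84, 84], **cfg_rgb_model)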