From 98484ac68ed36247b3721c072b5c0637ce33570c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 12 Mar 2024 21:59:01 +0000 Subject: [PATCH] ready for review --- .../diffusion/model/multi_image_obs_encoder.py | 12 ++++-------- lerobot/configs/policy/diffusion.yaml | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index 91472dd5..6a1d3c0d 100644 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, Tuple, Union +from typing import Dict, Optional, Tuple, Union import timm import torch @@ -46,7 +46,7 @@ class MultiImageObsEncoder(ModuleAttrMixin): share_rgb_model: bool = False, # renormalize rgb input with imagenet normalization # assuming input in [0,1] - imagenet_norm: bool = False, + norm_mean_std: Optional[tuple[float, float]] = None, ): """ Assumes rgb input: B,C,H,W @@ -120,13 +120,9 @@ class MultiImageObsEncoder(ModuleAttrMixin): this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) # configure normalizer this_normalizer = nn.Identity() - if imagenet_norm: - # TODO(rcadene): move normalizer to dataset and env + if norm_mean_std is not None: this_normalizer = torchvision.transforms.Normalize( - # Note: This matches the normalization in the original impl. for PushT Image. This may not be - # the case for other tasks. - mean=[127.5, 127.5, 127.5], - std=[127.5, 127.5, 127.5], + mean=norm_mean_std[0], std=norm_mean_std[1] ) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index f07e4754..7de44102 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -81,7 +81,7 @@ obs_encoder: # random_crop: True use_group_norm: True share_rgb_model: False - imagenet_norm: True + norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) rgb_model: model_name: resnet18