ready for review

This commit is contained in:
Alexander Soare 2024-03-12 21:59:01 +00:00
parent 9512d1d2f3
commit 98484ac68e
2 changed files with 5 additions and 9 deletions

View File

@ -1,5 +1,5 @@
import copy
from typing import Dict, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import timm
import torch
@ -46,7 +46,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
imagenet_norm: bool = False,
norm_mean_std: Optional[tuple[float, float]] = None,
):
"""
Assumes rgb input: B,C,H,W
@ -120,13 +120,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
if norm_mean_std is not None:
this_normalizer = torchvision.transforms.Normalize(
# Note: This matches the normalization in the original impl. for PushT Image. This may not be
# the case for other tasks.
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
mean=norm_mean_std[0], std=norm_mean_std[1]
)
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)

View File

@ -81,7 +81,7 @@ obs_encoder:
# random_crop: True
use_group_norm: True
share_rgb_model: False
imagenet_norm: True
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
rgb_model:
model_name: resnet18