ready for review
This commit is contained in:
parent
9512d1d2f3
commit
98484ac68e
|
@ -1,5 +1,5 @@
|
|||
import copy
|
||||
from typing import Dict, Tuple, Union
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import timm
|
||||
import torch
|
||||
|
@ -46,7 +46,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
|||
share_rgb_model: bool = False,
|
||||
# renormalize rgb input with imagenet normalization
|
||||
# assuming input in [0,1]
|
||||
imagenet_norm: bool = False,
|
||||
norm_mean_std: Optional[tuple[float, float]] = None,
|
||||
):
|
||||
"""
|
||||
Assumes rgb input: B,C,H,W
|
||||
|
@ -120,13 +120,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
|||
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
|
||||
# configure normalizer
|
||||
this_normalizer = nn.Identity()
|
||||
if imagenet_norm:
|
||||
# TODO(rcadene): move normalizer to dataset and env
|
||||
if norm_mean_std is not None:
|
||||
this_normalizer = torchvision.transforms.Normalize(
|
||||
# Note: This matches the normalization in the original impl. for PushT Image. This may not be
|
||||
# the case for other tasks.
|
||||
mean=[127.5, 127.5, 127.5],
|
||||
std=[127.5, 127.5, 127.5],
|
||||
mean=norm_mean_std[0], std=norm_mean_std[1]
|
||||
)
|
||||
|
||||
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
||||
|
|
|
@ -81,7 +81,7 @@ obs_encoder:
|
|||
# random_crop: True
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: True
|
||||
norm_mean_std: [0.5, 0.5] # for PushT, the original implementation normalizes inputs to [-1, 1] (this may not be the case for robomimic envs)
|
||||
|
||||
rgb_model:
|
||||
model_name: resnet18
|
||||
|
|
Loading…
Reference in New Issue