ready for review
This commit is contained in:
parent
9512d1d2f3
commit
98484ac68e
|
@ -1,5 +1,5 @@
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import timm
|
import timm
|
||||||
import torch
|
import torch
|
||||||
|
@ -46,7 +46,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
||||||
share_rgb_model: bool = False,
|
share_rgb_model: bool = False,
|
||||||
# renormalize rgb input with imagenet normalization
|
# renormalize rgb input with imagenet normalization
|
||||||
# assuming input in [0,1]
|
# assuming input in [0,1]
|
||||||
imagenet_norm: bool = False,
|
norm_mean_std: Optional[tuple[float, float]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Assumes rgb input: B,C,H,W
|
Assumes rgb input: B,C,H,W
|
||||||
|
@ -120,13 +120,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
||||||
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
|
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
|
||||||
# configure normalizer
|
# configure normalizer
|
||||||
this_normalizer = nn.Identity()
|
this_normalizer = nn.Identity()
|
||||||
if imagenet_norm:
|
if norm_mean_std is not None:
|
||||||
# TODO(rcadene): move normalizer to dataset and env
|
|
||||||
this_normalizer = torchvision.transforms.Normalize(
|
this_normalizer = torchvision.transforms.Normalize(
|
||||||
# Note: This matches the normalization in the original impl. for PushT Image. This may not be
|
mean=norm_mean_std[0], std=norm_mean_std[1]
|
||||||
# the case for other tasks.
|
|
||||||
mean=[127.5, 127.5, 127.5],
|
|
||||||
std=[127.5, 127.5, 127.5],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
||||||
|
|
|
@ -81,7 +81,7 @@ obs_encoder:
|
||||||
# random_crop: True
|
# random_crop: True
|
||||||
use_group_norm: True
|
use_group_norm: True
|
||||||
share_rgb_model: False
|
share_rgb_model: False
|
||||||
imagenet_norm: True
|
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
|
||||||
|
|
||||||
rgb_model:
|
rgb_model:
|
||||||
model_name: resnet18
|
model_name: resnet18
|
||||||
|
|
Loading…
Reference in New Issue