[Port HIL-SERL] Add HF vision encoder option in SAC (#651)
Added support for a custom pretrained vision encoder in the SAC modeling implementation. Great job @ChorntonYoel!
parent 7c89bd1018
commit f1c8bfe01e
@@ -74,7 +74,23 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
     image_transforms = None
     if cfg.training.image_transforms.enable:
         cfg_tf = cfg.training.image_transforms
+        default_tf = OmegaConf.create(
+            {
+                "brightness": {"weight": 0.0, "min_max": None},
+                "contrast": {"weight": 0.0, "min_max": None},
+                "saturation": {"weight": 0.0, "min_max": None},
+                "hue": {"weight": 0.0, "min_max": None},
+                "sharpness": {"weight": 0.0, "min_max": None},
+                "max_num_transforms": None,
+                "random_order": False,
+                "image_size": None,
+                "interpolation": None,
+                "image_mean": None,
+                "image_std": None,
+            }
+        )
+        cfg_tf = OmegaConf.merge(OmegaConf.create(default_tf), cfg.training.image_transforms)
+
         image_transforms = get_image_transforms(
             brightness_weight=cfg_tf.brightness.weight,
             brightness_min_max=cfg_tf.brightness.min_max,
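The OmegaConf.merge call above gives explicitly configured transform fields precedence over the defaults, so a partial user config no longer trips missing-key lookups like cfg_tf.image_size. A minimal standalone sketch of the merge semantics (not LeRobot code, just OmegaConf behavior):

    from omegaconf import OmegaConf

    default_tf = OmegaConf.create({"brightness": {"weight": 0.0, "min_max": None}, "image_size": None})
    user_tf = OmegaConf.create({"brightness": {"weight": 1.0, "min_max": [0.8, 1.2]}})

    merged = OmegaConf.merge(default_tf, user_tf)
    print(merged.brightness.weight)  # 1.0 -- the user-provided value wins
    print(merged.image_size)         # None -- the default fills the missing key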
@@ -88,6 +104,10 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
             sharpness_min_max=cfg_tf.sharpness.min_max,
             max_num_transforms=cfg_tf.max_num_transforms,
             random_order=cfg_tf.random_order,
+            image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) if cfg_tf.image_size else None,
+            interpolation=cfg_tf.interpolation,
+            image_mean=cfg_tf.image_mean,
+            image_std=cfg_tf.image_std,
         )

     if isinstance(cfg.dataset_repo_id, str):
@@ -150,6 +150,10 @@ def get_image_transforms(
     sharpness_min_max: tuple[float, float] | None = None,
     max_num_transforms: int | None = None,
     random_order: bool = False,
+    interpolation: str | None = None,
+    image_size: tuple[int, int] | None = None,
+    image_mean: list[float] | None = None,
+    image_std: list[float] | None = None,
 ):
     def check_value(name, weight, min_max):
         if min_max is not None:
@@ -170,6 +174,18 @@ def get_image_transforms(

     weights = []
     transforms = []
+    if image_size is not None:
+        interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
+        if interpolation is None:
+            # Use BICUBIC as default interpolation
+            interpolation_mode = v2.InterpolationMode.BICUBIC
+        elif interpolation in interpolations:
+            interpolation_mode = v2.InterpolationMode(interpolation)
+        else:
+            raise ValueError("The interpolation passed is not supported")
+        # Weight for resizing is always 1
+        weights.append(1.0)
+        transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
     if brightness_min_max is not None and brightness_weight > 0.0:
         weights.append(brightness_weight)
         transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -185,6 +201,15 @@ def get_image_transforms(
     if sharpness_min_max is not None and sharpness_weight > 0.0:
         weights.append(sharpness_weight)
         transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
+    if image_mean is not None and image_std is not None:
+        # Weight for normalization is always 1
+        weights.append(1.0)
+        transforms.append(
+            v2.Normalize(
+                mean=image_mean,
+                std=image_std,
+            )
+        )

     n_subset = len(transforms)
     if max_num_transforms is not None:
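With these additions, resizing and normalization join the pipeline as always-on transforms (weight 1.0), while the color/sharpness jitters stay stochastic. A usage sketch of the extended signature, assuming the keyword names shown in the diff above (argument values are illustrative only):

    # Hypothetical call: resize to 224x224 with bicubic interpolation, normalize
    # with ImageNet statistics, and apply a mild brightness jitter.
    transform = get_image_transforms(
        brightness_weight=1.0,
        brightness_min_max=(0.8, 1.2),
        image_size=(224, 224),
        interpolation="bicubic",  # must match a v2.InterpolationMode value
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
    )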
@@ -55,6 +55,7 @@ class SACConfig:
     )
     camera_number: int = 1
+    vision_encoder_name: str = field(default="microsoft/resnet-18")
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = False
     discount: float = 0.99
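The field effectively toggles the two encoder paths: the modeling code below checks vision_encoder_name against None and falls back to the from-scratch CNN when no name is given. A hypothetical instantiation (values are illustrative; all other fields keep their defaults):

    # HF Hub checkpoint (the default); a CNN-style model with a pooled output is expected.
    cfg = SACConfig(vision_encoder_name="microsoft/resnet-18", camera_number=2)

    # Passing None selects the from-scratch Conv2d stack in SACObservationEncoder.
    cfg_scratch = SACConfig(vision_encoder_name=None)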
@@ -473,54 +473,61 @@ class SACObservationEncoder(nn.Module):
         """
         super().__init__()
         self.config = config
+        self.has_pretrained_vision_encoder = False
         if "observation.image" in config.input_shapes:
-            self.image_enc_layers = nn.Sequential(
-                nn.Conv2d(
-                    in_channels=config.input_shapes["observation.image"][0],
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=7,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=5,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=3,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=3,
-                    stride=2,
-                ),
-                nn.ReLU(),
-            )
             self.camera_number = config.camera_number
             self.aggregation_size: int = 0

-            dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
-            with torch.inference_mode():
-                out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-            self.image_enc_layers.extend(
-                sequential=nn.Sequential(
-                    nn.Flatten(),
-                    nn.Linear(
-                        in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
-                    ),
-                    nn.LayerNorm(normalized_shape=config.latent_dim),
+            if self.config.vision_encoder_name is not None:
+                self.has_pretrained_vision_encoder = True
+                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
+                self.freeze_encoder()
+                self.image_enc_proj = nn.Sequential(
+                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                    nn.LayerNorm(config.latent_dim),
                     nn.Tanh(),
                 )
-            )

+            else:
+                self.image_enc_layers = nn.Sequential(
+                    nn.Conv2d(
+                        in_channels=config.input_shapes["observation.image"][0],
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=7,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=5,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=3,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=3,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                )
+                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+                with torch.inference_mode():
+                    self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
+                self.image_enc_layers.extend(
+                    nn.Sequential(
+                        nn.Flatten(),
+                        nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                        nn.LayerNorm(config.latent_dim),
+                        nn.Tanh(),
+                    )
+                )
             self.aggregation_size += config.latent_dim * self.camera_number
         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
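The scratch branch still sizes its projection head via a dummy forward pass instead of hand-computed convolution arithmetic. The same trick in isolation (standalone sketch; shapes are illustrative):

    import math

    import torch
    from torch import nn

    conv = nn.Sequential(nn.Conv2d(3, 32, kernel_size=7, stride=2), nn.ReLU())
    with torch.inference_mode():
        out_shape = conv(torch.zeros(1, 3, 84, 84)).shape[1:]  # (32, 39, 39)
    head = nn.Linear(math.prod(out_shape), 256)  # 48672 -> 256, no manual bookkeeping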
@@ -541,10 +548,27 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )

             self.aggregation_size += config.latent_dim
         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)

+    def _load_pretrained_vision_encoder(self):
+        """Set up CNN encoder"""
+        from transformers import AutoModel
+
+        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
+        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
+            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
+        elif hasattr(self.image_enc_layers, "fc"):
+            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
+        else:
+            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
+        return self.image_enc_layers, self.image_enc_out_shape
+
+    def freeze_encoder(self):
+        """Freeze all parameters in the encoder"""
+        for param in self.image_enc_layers.parameters():
+            param.requires_grad = False

     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
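_load_pretrained_vision_encoder infers the feature width from the checkpoint config rather than hardcoding it. The same probe outside the class, assuming transformers is installed (the 512 is specific to resnet-18):

    from transformers import AutoModel

    encoder = AutoModel.from_pretrained("microsoft/resnet-18")
    # ResNet-style configs expose per-stage channel widths; the last entry is the
    # width of the pooled feature that .pooler_output returns.
    feat_dim = encoder.config.hidden_sizes[-1]  # 512 for resnet-18

    # Same effect as freeze_encoder(): the backbone receives no gradient updates.
    for param in encoder.parameters():
        param.requires_grad = False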
@@ -555,7 +579,13 @@ class SACObservationEncoder(nn.Module):
             # Concatenate all images along the channel dimension.
             image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
             for image_key in image_keys:
-                feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
+                if self.has_pretrained_vision_encoder:
+                    enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
+                    enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+                else:
+                    enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
+
+                feat.append(enc_feat)
         if "observation.environment_state" in self.config.input_shapes:
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
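On the pretrained path, pooler_output is flattened with view(...) before image_enc_proj, which keeps the code robust whether the pooled feature comes back 2-D or 4-D. A shape sketch (the tensor is a stand-in for a real forward pass; dimensions assume resnet-18, whose pooled output is channel-width 512):

    import torch

    enc_feat = torch.randn(4, 512, 1, 1)         # stand-in for .pooler_output on a batch of 4
    flat = enc_feat.view(enc_feat.shape[0], -1)  # (4, 512), ready for image_enc_proj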