[Port HIL-SERL] Add HF vision encoder option in SAC (#651)

Added support with custom pretrained vision encoder to the modeling sac implementation. Great job @ChorntonYoel !
This commit is contained in:
Yoel 2025-01-31 09:42:13 +01:00 committed by Michel Aractingi
parent 7c89bd1018
commit f1c8bfe01e
4 changed files with 123 additions and 47 deletions

View File

@ -74,7 +74,23 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
image_transforms = None
if cfg.training.image_transforms.enable:
cfg_tf = cfg.training.image_transforms
default_tf = OmegaConf.create(
{
"brightness": {"weight": 0.0, "min_max": None},
"contrast": {"weight": 0.0, "min_max": None},
"saturation": {"weight": 0.0, "min_max": None},
"hue": {"weight": 0.0, "min_max": None},
"sharpness": {"weight": 0.0, "min_max": None},
"max_num_transforms": None,
"random_order": False,
"image_size": None,
"interpolation": None,
"image_mean": None,
"image_std": None,
}
)
cfg_tf = OmegaConf.merge(OmegaConf.create(default_tf), cfg.training.image_transforms)
image_transforms = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
@ -88,6 +104,10 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) if cfg_tf.image_size else None,
interpolation=cfg_tf.interpolation,
image_mean=cfg_tf.image_mean,
image_std=cfg_tf.image_std,
)
if isinstance(cfg.dataset_repo_id, str):

View File

@ -150,6 +150,10 @@ def get_image_transforms(
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
interpolation: str | None = None,
image_size: tuple[int, int] | None = None,
image_mean: list[float] | None = None,
image_std: list[float] | None = None,
):
def check_value(name, weight, min_max):
if min_max is not None:
@ -170,6 +174,18 @@ def get_image_transforms(
weights = []
transforms = []
if image_size is not None:
interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
if interpolation is None:
# Use BICUBIC as default interpolation
interpolation_mode = v2.InterpolationMode.BICUBIC
elif interpolation in interpolations:
interpolation_mode = v2.InterpolationMode(interpolation)
else:
raise ValueError("The interpolation passed is not supported")
# Weight for resizing is always 1
weights.append(1.0)
transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@ -185,6 +201,15 @@ def get_image_transforms(
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
if image_mean is not None and image_std is not None:
# Weight for normalization is always 1
weights.append(1.0)
transforms.append(
v2.Normalize(
mean=image_mean,
std=image_std,
)
)
n_subset = len(transforms)
if max_num_transforms is not None:

View File

@ -55,6 +55,7 @@ class SACConfig:
)
camera_number: int = 1
# Add type annotations for these fields:
vision_encoder_name: str = field(default="microsoft/resnet-18")
image_encoder_hidden_dim: int = 32
shared_encoder: bool = False
discount: float = 0.99

View File

@ -473,54 +473,61 @@ class SACObservationEncoder(nn.Module):
"""
super().__init__()
self.config = config
self.has_pretrained_vision_encoder = False
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
self.camera_number = config.camera_number
self.aggregation_size: int = 0
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
sequential=nn.Sequential(
nn.Flatten(),
nn.Linear(
in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
),
nn.LayerNorm(normalized_shape=config.latent_dim),
if self.config.vision_encoder_name is not None:
self.has_pretrained_vision_encoder = True
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
self.freeze_encoder()
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
else:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
self.aggregation_size += config.latent_dim * self.camera_number
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
@ -541,10 +548,27 @@ class SACObservationEncoder(nn.Module):
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.aggregation_size += config.latent_dim
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
def _load_pretrained_vision_encoder(self):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
return self.image_enc_layers, self.image_enc_out_shape
def freeze_encoder(self):
"""Freeze all parameters in the encoder"""
for param in self.image_enc_layers.parameters():
param.requires_grad = False
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
@ -555,7 +579,13 @@ class SACObservationEncoder(nn.Module):
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
for image_key in image_keys:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
if self.has_pretrained_vision_encoder:
enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
else:
enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
feat.append(enc_feat)
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes: