diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index f6164ed1..2f280372 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -74,7 +74,23 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
     image_transforms = None
     if cfg.training.image_transforms.enable:
-        cfg_tf = cfg.training.image_transforms
+        default_tf = OmegaConf.create(
+            {
+                "brightness": {"weight": 0.0, "min_max": None},
+                "contrast": {"weight": 0.0, "min_max": None},
+                "saturation": {"weight": 0.0, "min_max": None},
+                "hue": {"weight": 0.0, "min_max": None},
+                "sharpness": {"weight": 0.0, "min_max": None},
+                "max_num_transforms": None,
+                "random_order": False,
+                "image_size": None,
+                "interpolation": None,
+                "image_mean": None,
+                "image_std": None,
+            }
+        )
+        cfg_tf = OmegaConf.merge(OmegaConf.create(default_tf), cfg.training.image_transforms)
+
         image_transforms = get_image_transforms(
             brightness_weight=cfg_tf.brightness.weight,
             brightness_min_max=cfg_tf.brightness.min_max,
@@ -88,6 +104,10 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
             sharpness_min_max=cfg_tf.sharpness.min_max,
             max_num_transforms=cfg_tf.max_num_transforms,
             random_order=cfg_tf.random_order,
+            image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) if cfg_tf.image_size else None,
+            interpolation=cfg_tf.interpolation,
+            image_mean=cfg_tf.image_mean,
+            image_std=cfg_tf.image_std,
         )

     if isinstance(cfg.dataset_repo_id, str):
diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py
index 899f0d66..1a72e68e 100644
--- a/lerobot/common/datasets/transforms.py
+++ b/lerobot/common/datasets/transforms.py
@@ -150,6 +150,10 @@ def get_image_transforms(
     sharpness_min_max: tuple[float, float] | None = None,
     max_num_transforms: int | None = None,
     random_order: bool = False,
+    interpolation: str | None = None,
+    image_size: tuple[int, int] | None = None,
+    image_mean: list[float] | None = None,
+    image_std: list[float] | None = None,
 ):
     def check_value(name, weight, min_max):
         if min_max is not None:
@@ -170,6 +174,18 @@ def get_image_transforms(

     weights = []
     transforms = []
+    if image_size is not None:
+        interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
+        if interpolation is None:
+            # Use BICUBIC as default interpolation
+            interpolation_mode = v2.InterpolationMode.BICUBIC
+        elif interpolation in interpolations:
+            interpolation_mode = v2.InterpolationMode(interpolation)
+        else:
+            raise ValueError("The interpolation passed is not supported")
+        # Weight for resizing is always 1
+        weights.append(1.0)
+        transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
     if brightness_min_max is not None and brightness_weight > 0.0:
         weights.append(brightness_weight)
         transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -185,6 +201,15 @@ def get_image_transforms(
     if sharpness_min_max is not None and sharpness_weight > 0.0:
         weights.append(sharpness_weight)
         transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
+    if image_mean is not None and image_std is not None:
+        # Weight for normalization is always 1
+        weights.append(1.0)
+        transforms.append(
+            v2.Normalize(
+                mean=image_mean,
+                std=image_std,
+            )
+        )

     n_subset = len(transforms)
     if max_num_transforms is not None:
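Note: a minimal sketch of how the extended get_image_transforms signature can be called once this patch is applied. The parameter names come from the diff above; the concrete size, interpolation mode, and normalization statistics are illustrative assumptions, not values taken from this patch.

from lerobot.common.datasets.transforms import get_image_transforms

image_transforms = get_image_transforms(
    brightness_weight=1.0,
    brightness_min_max=(0.8, 1.2),
    # New in this patch: optional resize with a named torchvision interpolation mode.
    image_size=(224, 224),
    interpolation="bicubic",
    # New in this patch: optional normalization appended with weight 1.0.
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
)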
diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 904679e8..3c6344de 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -55,6 +55,7 @@ class SACConfig:
     )
     camera_number: int = 1
     # Add type annotations for these fields:
+    vision_encoder_name: str = field(default="microsoft/resnet-18")
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = False
     discount: float = 0.99
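For context, a sketch of how the new field is expected to be used (assuming the remaining SACConfig fields are left at their defaults): a Hugging Face model id selects the frozen pretrained backbone added in modeling_sac.py below, while None falls back to the existing from-scratch CNN encoder via the runtime None check in the encoder.

from lerobot.common.policies.sac.configuration_sac import SACConfig

# Pretrained, frozen ResNet-18 backbone (the default introduced by this patch).
config = SACConfig(vision_encoder_name="microsoft/resnet-18")

# From-scratch CNN encoder path (the pre-existing behaviour).
config = SACConfig(vision_encoder_name=None)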
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 64688b1b..bd6e9ef2 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -473,54 +473,61 @@ class SACObservationEncoder(nn.Module):
         """
        super().__init__()
        self.config = config
+        self.has_pretrained_vision_encoder = False
        if "observation.image" in config.input_shapes:
-            self.image_enc_layers = nn.Sequential(
-                nn.Conv2d(
-                    in_channels=config.input_shapes["observation.image"][0],
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=7,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=5,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=3,
-                    stride=2,
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=config.image_encoder_hidden_dim,
-                    out_channels=config.image_encoder_hidden_dim,
-                    kernel_size=3,
-                    stride=2,
-                ),
-                nn.ReLU(),
-            )
            self.camera_number = config.camera_number
            self.aggregation_size: int = 0
-
-            dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
-            with torch.inference_mode():
-                out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-            self.image_enc_layers.extend(
-                sequential=nn.Sequential(
-                    nn.Flatten(),
-                    nn.Linear(
-                        in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
-                    ),
-                    nn.LayerNorm(normalized_shape=config.latent_dim),
+            if self.config.vision_encoder_name is not None:
+                self.has_pretrained_vision_encoder = True
+                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
+                self.freeze_encoder()
+                self.image_enc_proj = nn.Sequential(
+                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                    nn.LayerNorm(config.latent_dim),
                    nn.Tanh(),
                )
-            )
-
+            else:
+                self.image_enc_layers = nn.Sequential(
+                    nn.Conv2d(
+                        in_channels=config.input_shapes["observation.image"][0],
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=7,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=5,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=3,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                    nn.Conv2d(
+                        in_channels=config.image_encoder_hidden_dim,
+                        out_channels=config.image_encoder_hidden_dim,
+                        kernel_size=3,
+                        stride=2,
+                    ),
+                    nn.ReLU(),
+                )
+                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+                with torch.inference_mode():
+                    self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
+                self.image_enc_layers.extend(
+                    nn.Sequential(
+                        nn.Flatten(),
+                        nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                        nn.LayerNorm(config.latent_dim),
+                        nn.Tanh(),
+                    )
+                )
            self.aggregation_size += config.latent_dim * self.camera_number
        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
@@ -541,10 +548,27 @@ class SACObservationEncoder(nn.Module):
                nn.LayerNorm(normalized_shape=config.latent_dim),
                nn.Tanh(),
            )
-
-            self.aggregation_size += config.latent_dim
+            self.aggregation_size += config.latent_dim

        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)

+    def _load_pretrained_vision_encoder(self):
+        """Set up CNN encoder"""
+        from transformers import AutoModel
+
+        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
+        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
+            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
+        elif hasattr(self.image_enc_layers, "fc"):
+            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
+        else:
+            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
+        return self.image_enc_layers, self.image_enc_out_shape
+
+    def freeze_encoder(self):
+        """Freeze all parameters in the encoder"""
+        for param in self.image_enc_layers.parameters():
+            param.requires_grad = False
+
    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.

@@ -555,7 +579,13 @@ class SACObservationEncoder(nn.Module):
        # Concatenate all images along the channel dimension.
        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
        for image_key in image_keys:
-            feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
+            if self.has_pretrained_vision_encoder:
+                enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
+                enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+            else:
+                enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
+
+            feat.append(enc_feat)
        if "observation.environment_state" in self.config.input_shapes:
            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
        if "observation.state" in self.config.input_shapes:
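To make the new pretrained path concrete, here is a minimal standalone sketch of what _load_pretrained_vision_encoder and the updated forward pass do for a CNN backbone. The 224x224 input and the 512-channel output are assumptions based on the "microsoft/resnet-18" default, not guarantees made by this diff.

import torch
from transformers import AutoModel

encoder = AutoModel.from_pretrained("microsoft/resnet-18")
for param in encoder.parameters():
    param.requires_grad = False  # mirrors freeze_encoder()

# hidden_sizes[-1] is the last channel dimension of the CNN (512 for ResNet-18).
out_dim = encoder.config.hidden_sizes[-1]

with torch.no_grad():
    pooled = encoder(torch.zeros(1, 3, 224, 224)).pooler_output
    # Flattened before the Linear/LayerNorm/Tanh projection, as in forward().
    features = pooled.view(pooled.shape[0], -1)  # shape: (1, out_dim)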