diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 2246bf8c..b5bfb36e 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -187,8 +187,8 @@ class SACPolicy(
         """Select action for inference/evaluation"""
         # We cached the encoder output to avoid recomputing it
         observations_features = None
-        if self.shared_encoder and self.actor.encoder is not None:
-            observations_features = self.actor.encoder(batch)
+        if self.shared_encoder:
+            observations_features = self.actor.encoder.get_image_features(batch)

         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -484,6 +484,109 @@ class SACPolicy(
         return actor_loss


+class SACObservationEncoder(nn.Module):
+    """Encode image and/or state vector observations."""
+
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
+        """
+        Creates encoders for pixel and/or state modalities.
+        """
+        super().__init__()
+        self.config = config
+        self.input_normalization = input_normalizer
+        self.has_pretrained_vision_encoder = False
+        self.parameters_to_optimize = []
+
+        self.aggregation_size: int = 0
+        if any("observation.image" in key for key in config.input_features):
+            self.camera_number = config.camera_number
+
+            if self.config.vision_encoder_name is not None:
+                self.image_enc_layers = PretrainedImageEncoder(config)
+                self.has_pretrained_vision_encoder = True
+            else:
+                self.image_enc_layers = DefaultImageEncoder(config)
+
+            self.aggregation_size += config.latent_dim * self.camera_number
+
+            if config.freeze_vision_encoder:
+                freeze_image_encoder(self.image_enc_layers)
+            else:
+                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+        self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
+
+        if "observation.state" in config.input_features:
+            self.state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_features["observation.state"].shape[0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+
+            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
+
+        if "observation.environment_state" in config.input_features:
+            self.env_state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_features["observation.environment_state"].shape[0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
+
+        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
+        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
+
+    def forward(
+        self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
+    ) -> Tensor:
+        """Encode the image and/or state vector.
+
+        Each modality is encoded into a feature vector of size (latent_dim,); the features
+        are concatenated and projected back to (latent_dim,) by the aggregation layer.
+        """
+        feat = []
+        obs_dict = self.input_normalization(obs_dict)
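+        # Consumers that share this encoder (actor, critics, learner server) may pass
+        # precomputed image features so the vision backbone is not run a second time.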
+ """ + feat = [] + obs_dict = self.input_normalization(obs_dict) + if len(self.all_image_keys) > 0 and vision_encoder_cache is None: + vision_encoder_cache = self.get_image_features(obs_dict) + feat.append(vision_encoder_cache) + + if vision_encoder_cache is not None: + feat.append(vision_encoder_cache) + + if "observation.environment_state" in self.config.input_features: + feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + if "observation.state" in self.config.input_features: + feat.append(self.state_enc_layers(obs_dict["observation.state"])) + + features = torch.cat(tensors=feat, dim=-1) + features = self.aggregation_layer(features) + + return features + + def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor: + # [N*B, C, H, W] + if len(self.all_image_keys) > 0: + # Batch all images along the batch dimension, then encode them. + images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0) + images_batched = self.image_enc_layers(images_batched) + embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) + embeddings_image = torch.cat(embeddings_chunks, dim=-1) + return embeddings_image + return None + + @property + def output_dim(self) -> int: + """Returns the dimension of the encoder output""" + return self.config.latent_dim + + class MLP(nn.Module): def __init__( self, @@ -606,7 +709,7 @@ class CriticEnsemble(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: SACObservationEncoder, ensemble: List[CriticHead], output_normalization: nn.Module, init_final: Optional[float] = None, @@ -638,11 +741,7 @@ class CriticEnsemble(nn.Module): actions = self.output_normalization(actions)["action"] actions = actions.to(device) - obs_enc = ( - observation_features - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, observation_features) inputs = torch.cat([obs_enc, actions], dim=-1) @@ -659,7 +758,7 @@ class CriticEnsemble(nn.Module): class GraspCritic(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: nn.Module, input_dim: int, hidden_dims: list[int], output_dim: int = 3, @@ -699,19 +798,14 @@ class GraspCritic(nn.Module): device = get_device_from_parameters(self) # Move each tensor in observations to device by cloning first to avoid inplace operations observations = {k: v.to(device) for k, v in observations.items()} - # Encode observations if encoder exists - obs_enc = ( - observation_features.to(device) - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) return self.output_layer(self.net(obs_enc)) class Policy(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: SACObservationEncoder, network: nn.Module, action_dim: int, log_std_min: float = -5, @@ -722,7 +816,7 @@ class Policy(nn.Module): encoder_is_shared: bool = False, ): super().__init__() - self.encoder = encoder + self.encoder: SACObservationEncoder = encoder self.network = network self.action_dim = action_dim self.log_std_min = log_std_min @@ -765,11 +859,7 @@ class Policy(nn.Module): observation_features: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Encode observations if encoder exists - obs_enc = ( - observation_features - if observation_features is not None - else (observations if self.encoder 
+        if len(self.all_image_keys) > 0:
+            # Batch all images along the batch dimension, then encode them.
+            images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
+            images_batched = self.image_enc_layers(images_batched)
+            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
+            embeddings_image = torch.cat(embeddings_chunks, dim=-1)
+            return embeddings_image
+        return None
+
+    @property
+    def output_dim(self) -> int:
+        """Returns the dimension of the encoder output"""
+        return self.config.latent_dim
+
+
 class MLP(nn.Module):
     def __init__(
         self,
@@ -606,7 +709,7 @@ class CriticEnsemble(nn.Module):

     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: SACObservationEncoder,
         ensemble: List[CriticHead],
         output_normalization: nn.Module,
         init_final: Optional[float] = None,
@@ -638,11 +741,7 @@ class CriticEnsemble(nn.Module):
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)

-        obs_enc = (
-            observation_features
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, observation_features)

         inputs = torch.cat([obs_enc, actions], dim=-1)

@@ -659,7 +758,7 @@ class CriticEnsemble(nn.Module):
 class GraspCritic(nn.Module):
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: nn.Module,
         input_dim: int,
         hidden_dims: list[int],
         output_dim: int = 3,
@@ -699,19 +798,14 @@ class GraspCritic(nn.Module):
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
-        # Encode observations if encoder exists
-        obs_enc = (
-            observation_features.to(device)
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
         return self.output_layer(self.net(obs_enc))


 class Policy(nn.Module):
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: SACObservationEncoder,
         network: nn.Module,
         action_dim: int,
         log_std_min: float = -5,
@@ -722,7 +816,7 @@ class Policy(nn.Module):
         encoder_is_shared: bool = False,
     ):
         super().__init__()
-        self.encoder = encoder
+        self.encoder: SACObservationEncoder = encoder
         self.network = network
         self.action_dim = action_dim
         self.log_std_min = log_std_min
@@ -765,11 +859,7 @@ class Policy(nn.Module):
         observation_features: torch.Tensor | None = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        obs_enc = (
-            observation_features
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)

         # Get network outputs
         outputs = self.network(obs_enc)
@@ -813,96 +903,6 @@ class Policy(nn.Module):
         return observations


-class SACObservationEncoder(nn.Module):
-    """Encode image and/or state vector observations."""
-
-    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
-        """
-        Creates encoders for pixel and/or state modalities.
-        """
-        super().__init__()
-        self.config = config
-        self.input_normalization = input_normalizer
-        self.has_pretrained_vision_encoder = False
-        self.parameters_to_optimize = []
-
-        self.aggregation_size: int = 0
-        if any("observation.image" in key for key in config.input_features):
-            self.camera_number = config.camera_number
-
-            if self.config.vision_encoder_name is not None:
-                self.image_enc_layers = PretrainedImageEncoder(config)
-                self.has_pretrained_vision_encoder = True
-            else:
-                self.image_enc_layers = DefaultImageEncoder(config)
-
-            self.aggregation_size += config.latent_dim * self.camera_number
-
-            if config.freeze_vision_encoder:
-                freeze_image_encoder(self.image_enc_layers)
-            else:
-                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
-
-        if "observation.state" in config.input_features:
-            self.state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-
-            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
-
-        if "observation.environment_state" in config.input_features:
-            self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.environment_state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
-
-        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
-
-    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
-        """Encode the image and/or state vector.
-
-        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
-        over all features.
-        """
-        feat = []
-        obs_dict = self.input_normalization(obs_dict)
-        # Batch all images along the batch dimension, then encode them.
-        if len(self.all_image_keys) > 0:
-            images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
-            images_batched = self.image_enc_layers(images_batched)
-            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
-            feat.extend(embeddings_chunks)
-
-        if "observation.environment_state" in self.config.input_features:
-            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
-        if "observation.state" in self.config.input_features:
-            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
-
-        features = torch.cat(tensors=feat, dim=-1)
-        features = self.aggregation_layer(features)
-
-        return features
-
-    @property
-    def output_dim(self) -> int:
-        """Returns the dimension of the encoder output"""
-        return self.config.latent_dim
-
-
 class DefaultImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 65b1d9b8..37586fe9 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    # Mirror the actor: optimize only parameters_to_optimize, which excludes the
+    # weights of a frozen vision encoder (see SACObservationEncoder.__init__).
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
+    )

     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
@@ -1024,12 +1026,8 @@ def get_observation_features(
         return None, None

     with torch.no_grad():
-        observation_features = (
-            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
-        )
-        next_observation_features = (
-            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
-        )
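+        # Only the image features are cached (computed once, under no_grad); state
+        # inputs are re-encoded by SACObservationEncoder.forward for every consumer.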
+        observation_features = policy.actor.encoder.get_image_features(observations)
+        next_observation_features = policy.actor.encoder.get_image_features(next_observations)

     return observation_features, next_observation_features