diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 3c6344de..7bb7f167 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -55,9 +55,10 @@ class SACConfig:
     )
     camera_number: int = 1
     # Add type annotations for these fields:
-    vision_encoder_name: str = field(default="microsoft/resnet-18")
+    vision_encoder_name: str | None = field(default="microsoft/resnet-18")
+    freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
-    shared_encoder: bool = False
+    shared_encoder: bool = True
     discount: float = 0.99
     temperature_init: float = 1.0
     num_critics: int = 2
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index bd6e9ef2..9faeeeb6 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -312,7 +312,7 @@ class CriticEnsemble(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        network_list: nn.Module,
+        network_list: nn.ModuleList,
         init_final: Optional[float] = None,
     ):
         super().__init__()
@@ -320,6 +320,12 @@ class CriticEnsemble(nn.Module):
         self.network_list = network_list
         self.init_final = init_final
 
+        self.parameters_to_optimize = []
+        # Handle the case where a part of the encoder is frozen
+        if self.encoder is not None:
+            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
+
+        self.parameters_to_optimize += list(self.network_list.parameters())
         # Find the last Linear layer's output dimension
         for layer in reversed(network_list[0].net):
             if isinstance(layer, nn.Linear):
@@ -342,6 +348,7 @@ class CriticEnsemble(nn.Module):
             self.output_layers.append(output_layer)
 
         self.output_layers = nn.ModuleList(self.output_layers)
+        self.parameters_to_optimize += list(self.output_layers.parameters())
 
     def forward(
         self,
@@ -474,61 +481,25 @@ class SACObservationEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.has_pretrained_vision_encoder = False
+        self.parameters_to_optimize = []
+
+        self.aggregation_size: int = 0
         if "observation.image" in config.input_shapes:
             self.camera_number = config.camera_number
-            self.aggregation_size: int = 0
 
+            if self.config.vision_encoder_name is not None:
+                self.image_enc_layers = PretrainedImageEncoder(config)
                 self.has_pretrained_vision_encoder = True
-                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
-                self.freeze_encoder()
-                self.image_enc_proj = nn.Sequential(
-                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-                    nn.LayerNorm(config.latent_dim),
-                    nn.Tanh(),
-                )
             else:
-                self.image_enc_layers = nn.Sequential(
-                    nn.Conv2d(
-                        in_channels=config.input_shapes["observation.image"][0],
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=7,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=5,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=3,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=3,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                )
-                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
-                with torch.inference_mode():
-                    self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-                self.image_enc_layers.extend(
-                    nn.Sequential(
-                        nn.Flatten(),
-                        nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-                        nn.LayerNorm(config.latent_dim),
-                        nn.Tanh(),
-                    )
-                )
+                self.image_enc_layers = DefaultImageEncoder(config)
+            self.aggregation_size += config.latent_dim * self.camera_number
+
+            if config.freeze_vision_encoder:
+                freeze_image_encoder(self.image_enc_layers)
+            else:
+                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+
         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -539,6 +510,8 @@ class SACObservationEncoder(nn.Module):
             )
             self.aggregation_size += config.latent_dim
 
+            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
+
         if "observation.environment_state" in config.input_shapes:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -548,26 +521,11 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
-            self.aggregation_size += config.latent_dim
+            self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
+
         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-
-    def _load_pretrained_vision_encoder(self):
-        """Set up CNN encoder"""
-        from transformers import AutoModel
-
-        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
-        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
-            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
-        elif hasattr(self.image_enc_layers, "fc"):
-            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
-        else:
-            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
-        return self.image_enc_layers, self.image_enc_out_shape
-
-    def freeze_encoder(self):
-        """Freeze all parameters in the encoder"""
-        for param in self.image_enc_layers.parameters():
-            param.requires_grad = False
+        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
 
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
@@ -579,12 +537,10 @@ class SACObservationEncoder(nn.Module):
         # Concatenate all images along the channel dimension.
         image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
         for image_key in image_keys:
-            if self.has_pretrained_vision_encoder:
-                enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
-                enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
-            else:
-                enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
+            enc_feat = self.image_enc_layers(obs_dict[image_key])
+            # if not self.has_pretrained_vision_encoder:
+            #     enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
             feat.append(enc_feat)
         if "observation.environment_state" in self.config.input_shapes:
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
@@ -602,10 +558,107 @@ class SACObservationEncoder(nn.Module):
         return self.config.latent_dim
 
 
+class DefaultImageEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.image_enc_layers = nn.Sequential(
+            nn.Conv2d(
+                in_channels=config.input_shapes["observation.image"][0],
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=7,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=5,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=3,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=3,
+                stride=2,
+            ),
+            nn.ReLU(),
+        )
+        dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+        with torch.inference_mode():
+            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
+        self.image_enc_layers.extend(
+            nn.Sequential(
+                nn.Flatten(),
+                nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                nn.LayerNorm(config.latent_dim),
+                nn.Tanh(),
+            )
+        )
+
+    def forward(self, x):
+        return self.image_enc_layers(x)
+
+
+class PretrainedImageEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
+        self.image_enc_proj = nn.Sequential(
+            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+            nn.LayerNorm(config.latent_dim),
+            nn.Tanh(),
+        )
+
+    def _load_pretrained_vision_encoder(self, config):
+        """Set up CNN encoder"""
+        from transformers import AutoModel
+
+        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name)
+        # self.image_enc_layers.pooler = Identity()
+
+        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
+            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
+        elif hasattr(self.image_enc_layers, "fc"):
+            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
+        else:
+            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
+        return self.image_enc_layers, self.image_enc_out_shape
+
+    def forward(self, x):
+        # TODO: (maractingi, azouitine) check that the forward pass of the pretrained model
+        # doesn't reach the classifier layer, because we don't need it
+        enc_feat = self.image_enc_layers(x).pooler_output
+        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+        return enc_feat
+
+
+def freeze_image_encoder(image_encoder: nn.Module):
+    """Freeze all parameters in the encoder"""
+    for param in image_encoder.parameters():
+        param.requires_grad = False
+
+
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
 
 
+class Identity(nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
 # TODO (azouitine): I think in our case this function is not usefull we should remove it
 # after some investigation
 # borrowed from tdmpc
@@ -626,3 +679,54 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
+
+
+if __name__ == "__main__":
+    # Test the SACObservationEncoder
+    import time
+
+    config = SACConfig()
+    config.num_critics = 10
+    encoder = SACObservationEncoder(config)
+    actor_encoder = SACObservationEncoder(config)
+    encoder = torch.compile(encoder)
+    critic_ensemble = CriticEnsemble(
+        encoder=encoder,
+        network_list=nn.ModuleList(
+            [
+                MLP(
+                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs,
+                )
+                for _ in range(config.num_critics)
+            ]
+        ),
+    )
+    actor = Policy(
+        encoder=actor_encoder,
+        network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
+        action_dim=config.output_shapes["action"][0],
+        encoder_is_shared=config.shared_encoder,
+        **config.policy_kwargs,
+    )
+    encoder = encoder.to("cuda:0")
+    critic_ensemble = torch.compile(critic_ensemble)
+    critic_ensemble = critic_ensemble.to("cuda:0")
+    actor = torch.compile(actor)
+    actor = actor.to("cuda:0")
+    obs_dict = {
+        "observation.image": torch.randn(1, 3, 84, 84),
+        "observation.state": torch.randn(1, 4),
+    }
+    actions = torch.randn(1, 2).to("cuda:0")
+    obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
+    print("compiling...")
+    # q_value = critic_ensemble(obs_dict, actions)
+    action = actor(obs_dict)
+    print("compiled")
+    start = time.perf_counter()
+    for _ in range(1000):
+        # features = encoder(obs_dict)
+        action = actor(obs_dict)
+        # q_value = critic_ensemble(obs_dict, actions)
+    print("Time taken:", time.perf_counter() - start)
diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml
index 2776b39d..aaf59e53 100644
--- a/lerobot/configs/policy/sac_maniskill.yaml
+++ b/lerobot/configs/policy/sac_maniskill.yaml
@@ -52,6 +52,8 @@ policy:
   n_action_steps: 1
 
   shared_encoder: true
+  # vision_encoder_name: null
+  freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
index 294f07a6..952590e8 100644
--- a/lerobot/scripts/server/actor_server.py
+++ b/lerobot/scripts/server/actor_server.py
@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
     #     pretrained_policy_name_or_path=None,
     #     device=device,
     # )
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     # HACK for maniskill
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
                     logging.debug("[ACTOR] Load new parameters from Learner.")
                     state_dict = parameters_queue.get()
                     state_dict = move_state_dict_to_device(state_dict, device=device)
-                    policy.actor.load_state_dict(state_dict)
+                    # strict=False for the case when the image encoder is frozen and not sent through
+                    # the network. Be careful: this might cause issues if the wrong keys are passed.
+                    policy.actor.load_state_dict(state_dict, strict=False)
 
                 if len(list_transition_to_send_to_learner) > 0:
                     logging.debug(
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index dbafeb42..6dd33fed 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -259,6 +259,9 @@ def learner_push_parameters(
     while True:
         with policy_lock:
             params_dict = policy.actor.state_dict()
+            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+                params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
+
         params_dict = move_state_dict_to_device(params_dict, device="cpu")
         # Serialize
         buf = io.BytesIO()
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         dataset_stats=None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
+    # compile policy
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py
index 936d65ee..4f7b55cc 100644
--- a/lerobot/scripts/train_sac.py
+++ b/lerobot/scripts/train_sac.py
@@ -13,26 +13,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import functools
-from pprint import pformat
+import logging
 import random
-from typing import Optional, Sequence, TypedDict, Callable
+from pprint import pformat
+from typing import Callable, Optional, Sequence, TypedDict
 
 import hydra
 import torch
 import torch.nn.functional as F
-from torch import nn
-from tqdm import tqdm
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
-
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from torch import nn
+from tqdm import tqdm
 
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.envs.factory import make_env, make_maniskill_env
-from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
+from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
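
Note on the learner/actor parameter round trip changed above: when the vision encoder is frozen, the learner drops every "encoder."-prefixed key from the actor's state dict before pushing, and the actor tolerates the missing keys by loading with strict=False. The standalone sketch below mirrors that push/load pair outside the lerobot codebase; the function names (push_actor_params, load_actor_params) and the bytes-based transport are illustrative assumptions, not part of this patch.

import io

import torch
from torch import nn


def push_actor_params(actor: nn.Module, freeze_vision_encoder: bool) -> bytes:
    """Serialize the actor's weights, optionally skipping the frozen image encoder."""
    params = actor.state_dict()
    if freeze_vision_encoder:
        # The frozen encoder never changes on the learner, so sending it is wasted bandwidth.
        params = {k: v for k, v in params.items() if not k.startswith("encoder.")}
    buf = io.BytesIO()
    torch.save(params, buf)
    return buf.getvalue()


def load_actor_params(actor: nn.Module, payload: bytes) -> None:
    """Load a possibly partial state dict on the actor side."""
    params = torch.load(io.BytesIO(payload))
    # strict=False tolerates the missing encoder keys, but it also silently ignores
    # unexpected keys, so the key prefixes must match the actor's module names.
    actor.load_state_dict(params, strict=False)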