From ff82367c628548ffcc52a747f43ecf2e8a42e487 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 20 Feb 2025 17:14:27 +0000 Subject: [PATCH] Refactor SAC policy with performance optimizations and multi-camera support - Introduced Ensemble and CriticHead classes for more efficient critic network handling - Added support for multiple camera inputs in observation encoder - Optimized image encoding by batching image processing - Updated configuration for ManiSkill environment with reduced image size and action scaling - Compiled critic networks for improved performance - Simplified normalization and ensemble handling in critic networks Co-authored-by: michel-aractingi --- lerobot/common/policies/sac/modeling_sac.py | 198 +++++++++++------- lerobot/configs/env/maniskill_example.yaml | 4 +- lerobot/configs/policy/sac_maniskill.yaml | 40 ++-- .../scripts/server/maniskill_manipulator.py | 4 +- 4 files changed, 153 insertions(+), 93 deletions(-) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 84ff6081..7cb41ebd 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -17,10 +17,12 @@ # TODO: (1) better device management +from copy import deepcopy from typing import Callable, Optional, Tuple import einops import numpy as np +from tensordict import from_modules import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 @@ -85,9 +87,9 @@ class SACPolicy( self.critic_ensemble = CriticEnsemble( encoder=encoder_critic, - network_list=nn.ModuleList( + ensemble=Ensemble( [ - MLP( + CriticHead( input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], **config.critic_network_kwargs, ) @@ -99,9 +101,9 @@ class SACPolicy( self.critic_target = CriticEnsemble( encoder=encoder_critic, - network_list=nn.ModuleList( + ensemble=Ensemble( [ - MLP( + CriticHead( input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], **config.critic_network_kwargs, ) @@ -113,6 +115,9 @@ class SACPolicy( self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + self.critic_ensemble = torch.compile(self.critic_ensemble) + self.critic_target = torch.compile(self.critic_target) + self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), @@ -274,6 +279,35 @@ class MLP(nn.Module): return self.net(x) +class CriticHead(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: Optional[float] = None, + init_final: Optional[float] = None, + ): + super().__init__() + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + ) + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output_layer(self.net(x)) + + class CriticEnsemble(nn.Module): """ ┌──────────────────┬─────────────────────────────────────────────────────────┐ @@ -316,13 +350,13 @@ class CriticEnsemble(nn.Module): def __init__( self, encoder: Optional[nn.Module], - network_list: nn.ModuleList, + ensemble: "Ensemble[CriticHead]", output_normalization: nn.Module, init_final: Optional[float] = None, ): super().__init__() self.encoder = encoder - self.network_list = network_list + self.ensemble = ensemble self.init_final = init_final self.output_normalization = output_normalization @@ -330,31 +364,7 @@ class CriticEnsemble(nn.Module): # Handle the case where a part of the encoder if frozen if self.encoder is not None: self.parameters_to_optimize += list(self.encoder.parameters_to_optimize) - - self.parameters_to_optimize += list(self.network_list.parameters()) - # Find the last Linear layer's output dimension - for layer in reversed(network_list[0].net): - if isinstance(layer, nn.Linear): - out_features = layer.out_features - break - - # Output layer - self.output_layers = [] - if init_final is not None: - for _ in network_list: - output_layer = nn.Linear(out_features, 1) - nn.init.uniform_(output_layer.weight, -init_final, init_final) - nn.init.uniform_(output_layer.bias, -init_final, init_final) - self.output_layers.append(output_layer) - else: - self.output_layers = [] - for _ in network_list: - output_layer = nn.Linear(out_features, 1) - orthogonal_init()(output_layer.weight) - self.output_layers.append(output_layer) - - self.output_layers = nn.ModuleList(self.output_layers) - self.parameters_to_optimize += list(self.output_layers.parameters()) + self.parameters_to_optimize += list(self.ensemble.parameters()) def forward( self, @@ -373,12 +383,8 @@ class CriticEnsemble(nn.Module): obs_enc = observations if self.encoder is None else self.encoder(observations) inputs = torch.cat([obs_enc, actions], dim=-1) - list_q_values = [] - for network, output_layer in zip(self.network_list, self.output_layers, strict=False): - x = network(inputs) - value = output_layer(x) - list_q_values.append(value.squeeze(-1)) - return torch.stack(list_q_values) + q_values = self.ensemble(inputs) # [num_critics, B, 1] + return q_values.squeeze(-1) # [num_critics, B] class Policy(nn.Module): @@ -510,6 +516,7 @@ class SACObservationEncoder(nn.Module): freeze_image_encoder(self.image_enc_layers) else: self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] if "observation.state" in config.input_shapes: self.state_enc_layers = nn.Sequential( @@ -546,14 +553,13 @@ class SACObservationEncoder(nn.Module): """ feat = [] obs_dict = self.input_normalization(obs_dict) - # Concatenate all images along the channel dimension. - image_keys = [k for k in obs_dict if k.startswith("observation.image")] - for image_key in image_keys: - enc_feat = self.image_enc_layers(obs_dict[image_key]) + # Batch all images along the batch dimension, then encode them. + if len(self.all_image_keys) > 0: + images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0) + images_batched = self.image_enc_layers(images_batched) + embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) + feat.extend(embeddings_chunks) - # if not self.has_pretrained_vision_encoder: - # enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]) - feat.append(enc_feat) if "observation.environment_state" in self.config.input_shapes: feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) if "observation.state" in self.config.input_shapes: @@ -671,6 +677,34 @@ class Identity(nn.Module): return x +class Ensemble(nn.Module): + """ + Vectorized ensemble of modules. + """ + + def __init__(self, modules, **kwargs): + super().__init__() + # combine_state_for_ensemble causes graph breaks + self.params = from_modules(*modules, as_module=True) + with self.params[0].data.to("meta").to_module(modules[0]): + self.module = deepcopy(modules[0]) + self._repr = str(modules[0]) + self._n = len(modules) + + def __len__(self): + return self._n + + def _call(self, params, *args, **kwargs): + with params.to_module(self.module): + return self.module(*args, **kwargs) + + def forward(self, *args, **kwargs): + return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs) + + def __repr__(self): + return f"Vectorized {len(self)}x " + self._repr + + # TODO (azouitine): I think in our case this function is not usefull we should remove it # after some investigation # borrowed from tdmpc @@ -711,46 +745,68 @@ if __name__ == "__main__": config = SACConfig() config.num_critics = 10 - encoder = SACObservationEncoder(config) - actor_encoder = SACObservationEncoder(config) - encoder = torch.compile(encoder) + config.vision_encoder_name = None + encoder = SACObservationEncoder(config, nn.Identity()) + # actor_encoder = SACObservationEncoder(config) + # encoder = torch.compile(encoder) critic_ensemble = CriticEnsemble( encoder=encoder, - network_list=nn.ModuleList( + ensemble=Ensemble( [ - MLP( + CriticHead( input_dim=encoder.output_dim + config.output_shapes["action"][0], **config.critic_network_kwargs, ) for _ in range(config.num_critics) ] ), + output_normalization=nn.Identity(), ) - actor = Policy( - encoder=actor_encoder, - network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs), - action_dim=config.output_shapes["action"][0], - encoder_is_shared=config.shared_encoder, - **config.policy_kwargs, - ) - encoder = encoder.to("cuda:0") - critic_ensemble = torch.compile(critic_ensemble) + # actor = Policy( + # encoder=actor_encoder, + # network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs), + # action_dim=config.output_shapes["action"][0], + # encoder_is_shared=config.shared_encoder, + # **config.policy_kwargs, + # ) + # encoder = encoder.to("cuda:0") + # critic_ensemble = torch.compile(critic_ensemble) critic_ensemble = critic_ensemble.to("cuda:0") - actor = torch.compile(actor) - actor = actor.to("cuda:0") + # actor = torch.compile(actor) + # actor = actor.to("cuda:0") obs_dict = { - "observation.image": torch.randn(1, 3, 84, 84), - "observation.state": torch.randn(1, 4), + "observation.image": torch.randn(8, 3, 84, 84), + "observation.state": torch.randn(8, 4), } - actions = torch.randn(1, 2).to("cuda:0") - obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()} - print("compiling...") - # q_value = critic_ensemble(obs_dict, actions) - action = actor(obs_dict) - print("compiled") + actions = torch.randn(8, 2).to("cuda:0") + # obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()} + # print("compiling...") + q_value = critic_ensemble(obs_dict, actions) + print(q_value.size()) + # action = actor(obs_dict) + # print("compiled") + # start = time.perf_counter() + # for _ in range(1000): + # # features = encoder(obs_dict) + # action = actor(obs_dict) + # # q_value = critic_ensemble(obs_dict, actions) + # print("Time taken:", time.perf_counter() - start) + # Compare the performance of the ensemble vs a for loop of 16 MLPs + ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)]) + ensemble = ensemble.to("cuda:0") + critic = CriticHead(256, [256, 256]) + critic = critic.to("cuda:0") + data_ensemble = torch.randn(8, 256).to("cuda:0") + ensemble = torch.compile(ensemble) + # critic = torch.compile(critic) + print(ensemble(data_ensemble).size()) + print(critic(data_ensemble).size()) start = time.perf_counter() for _ in range(1000): - # features = encoder(obs_dict) - action = actor(obs_dict) - # q_value = critic_ensemble(obs_dict, actions) + ensemble(data_ensemble) + print("Time taken:", time.perf_counter() - start) + start = time.perf_counter() + for _ in range(1000): + for i in range(2): + critic(data_ensemble) print("Time taken:", time.perf_counter() - start) diff --git a/lerobot/configs/env/maniskill_example.yaml b/lerobot/configs/env/maniskill_example.yaml index 03814614..2b9966c9 100644 --- a/lerobot/configs/env/maniskill_example.yaml +++ b/lerobot/configs/env/maniskill_example.yaml @@ -5,14 +5,14 @@ fps: 20 env: name: maniskill/pushcube task: PushCube-v1 - image_size: 128 + image_size: 64 control_mode: pd_ee_delta_pose state_dim: 25 action_dim: 7 fps: ${fps} obs: rgb render_mode: rgb_array - render_size: 128 + render_size: 64 device: cuda reward_classifier: diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml index 3edf7d67..e657434a 100644 --- a/lerobot/configs/policy/sac_maniskill.yaml +++ b/lerobot/configs/policy/sac_maniskill.yaml @@ -59,32 +59,36 @@ policy: input_shapes: # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? observation.state: ["${env.state_dim}"] - observation.image: [3, 128, 128] + observation.image: [3, 64, 64] + observation.image.2: [3, 64, 64] output_shapes: action: [7] + + camera_number: 2 # Normalization / Unnormalization - input_normalization_modes: - observation.state: min_max - input_normalization_params: - observation.state: - min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01, - 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00, - -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00, - -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01, - 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01] + input_normalization_modes: null + # input_normalization_modes: + # observation.state: min_max + input_normalization_params: null + # observation.state: + # min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01, + # 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00, + # -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00, + # -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01, + # 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01] - max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400, - 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163, - 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135, - 0.4001] + # max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400, + # 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163, + # 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135, + # 0.4001] output_normalization_modes: action: min_max output_normalization_params: action: - min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0] - max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] + min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0] + max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] output_normalization_shapes: action: [7] @@ -94,8 +98,8 @@ policy: # discount: 0.99 discount: 0.80 temperature_init: 1.0 - num_critics: 2 #10 - num_subsample_critics: null + num_critics: 10 #10 + num_subsample_critics: 2 critic_lr: 3e-4 actor_lr: 3e-4 temperature_lr: 3e-4 diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index b50698a9..105deeb4 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -10,7 +10,6 @@ from typing import Any from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv - def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]: """Convert environment observation to LeRobot format observation. Args: @@ -42,6 +41,7 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1) return_observations["observation.image"] = img + return_observations["observation.image.2"] = img return_observations["observation.state"] = state return return_observations @@ -142,7 +142,7 @@ def make_maniskill( env.unwrapped.metadata["render_fps"] = 20 env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) - env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0) + env = ManiSkillMultiplyActionWrapper(env, multiply_factor=1) return env