Refactor SAC policy with performance optimizations and multi-camera support

- Introduced Ensemble and CriticHead classes for more efficient critic network handling
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks
Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
This commit is contained in:
AdilZouitine 2025-02-20 17:14:27 +00:00
parent ff47c0b0d3
commit ff82367c62
4 changed files with 153 additions and 93 deletions

View File

@ -17,10 +17,12 @@
# TODO: (1) better device management # TODO: (1) better device management
from copy import deepcopy
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import einops import einops
import numpy as np import numpy as np
from tensordict import from_modules
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
@ -85,9 +87,9 @@ class SACPolicy(
self.critic_ensemble = CriticEnsemble( self.critic_ensemble = CriticEnsemble(
encoder=encoder_critic, encoder=encoder_critic,
network_list=nn.ModuleList( ensemble=Ensemble(
[ [
MLP( CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs, **config.critic_network_kwargs,
) )
@ -99,9 +101,9 @@ class SACPolicy(
self.critic_target = CriticEnsemble( self.critic_target = CriticEnsemble(
encoder=encoder_critic, encoder=encoder_critic,
network_list=nn.ModuleList( ensemble=Ensemble(
[ [
MLP( CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs, **config.critic_network_kwargs,
) )
@ -113,6 +115,9 @@ class SACPolicy(
self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.actor = Policy( self.actor = Policy(
encoder=encoder_actor, encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
@ -274,6 +279,35 @@ class MLP(nn.Module):
return self.net(x) return self.net(x)
class CriticHead(nn.Module):
    """One critic Q-head: an MLP trunk followed by a scalar output layer.

    The trunk maps the input features through ``hidden_dims`` layers; a
    final ``nn.Linear`` projects the last hidden width down to a single
    Q-value. When ``init_final`` is provided, the output layer's weight
    and bias are initialized uniformly in ``[-init_final, init_final]``;
    otherwise the output weight receives an orthogonal initialization.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
        activate_final: bool = False,
        dropout_rate: Optional[float] = None,
        init_final: Optional[float] = None,
    ):
        super().__init__()
        # MLP trunk shared by all configurations of this head.
        self.net = MLP(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            activations=activations,
            activate_final=activate_final,
            dropout_rate=dropout_rate,
        )
        # Scalar Q-value projection from the last hidden width.
        self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
        if init_final is None:
            orthogonal_init()(self.output_layer.weight)
        else:
            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scalar Q-value for each row of ``x``, shape (..., 1)."""
        features = self.net(x)
        return self.output_layer(features)
class CriticEnsemble(nn.Module): class CriticEnsemble(nn.Module):
""" """
@ -316,13 +350,13 @@ class CriticEnsemble(nn.Module):
def __init__( def __init__(
self, self,
encoder: Optional[nn.Module], encoder: Optional[nn.Module],
network_list: nn.ModuleList, ensemble: "Ensemble[CriticHead]",
output_normalization: nn.Module, output_normalization: nn.Module,
init_final: Optional[float] = None, init_final: Optional[float] = None,
): ):
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder
self.network_list = network_list self.ensemble = ensemble
self.init_final = init_final self.init_final = init_final
self.output_normalization = output_normalization self.output_normalization = output_normalization
@ -330,31 +364,7 @@ class CriticEnsemble(nn.Module):
# Handle the case where a part of the encoder is frozen # Handle the case where a part of the encoder is frozen
if self.encoder is not None: if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize) self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.ensemble.parameters())
self.parameters_to_optimize += list(self.network_list.parameters())
# Find the last Linear layer's output dimension
for layer in reversed(network_list[0].net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Output layer
self.output_layers = []
if init_final is not None:
for _ in network_list:
output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(output_layer.weight, -init_final, init_final)
nn.init.uniform_(output_layer.bias, -init_final, init_final)
self.output_layers.append(output_layer)
else:
self.output_layers = []
for _ in network_list:
output_layer = nn.Linear(out_features, 1)
orthogonal_init()(output_layer.weight)
self.output_layers.append(output_layer)
self.output_layers = nn.ModuleList(self.output_layers)
self.parameters_to_optimize += list(self.output_layers.parameters())
def forward( def forward(
self, self,
@ -373,12 +383,8 @@ class CriticEnsemble(nn.Module):
obs_enc = observations if self.encoder is None else self.encoder(observations) obs_enc = observations if self.encoder is None else self.encoder(observations)
inputs = torch.cat([obs_enc, actions], dim=-1) inputs = torch.cat([obs_enc, actions], dim=-1)
list_q_values = [] q_values = self.ensemble(inputs) # [num_critics, B, 1]
for network, output_layer in zip(self.network_list, self.output_layers, strict=False): return q_values.squeeze(-1) # [num_critics, B]
x = network(inputs)
value = output_layer(x)
list_q_values.append(value.squeeze(-1))
return torch.stack(list_q_values)
class Policy(nn.Module): class Policy(nn.Module):
@ -510,6 +516,7 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers) freeze_image_encoder(self.image_enc_layers)
else: else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters()) self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if "observation.state" in config.input_shapes: if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential( self.state_enc_layers = nn.Sequential(
@ -546,14 +553,13 @@ class SACObservationEncoder(nn.Module):
""" """
feat = [] feat = []
obs_dict = self.input_normalization(obs_dict) obs_dict = self.input_normalization(obs_dict)
# Concatenate all images along the channel dimension. # Batch all images along the batch dimension, then encode them.
image_keys = [k for k in obs_dict if k.startswith("observation.image")] if len(self.all_image_keys) > 0:
for image_key in image_keys: images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
enc_feat = self.image_enc_layers(obs_dict[image_key]) images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)
# if not self.has_pretrained_vision_encoder:
# enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
feat.append(enc_feat)
if "observation.environment_state" in self.config.input_shapes: if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes: if "observation.state" in self.config.input_shapes:
@ -671,6 +677,34 @@ class Identity(nn.Module):
return x return x
class Ensemble(nn.Module):
    """Vectorized ensemble of structurally identical modules.

    Stacks the parameters of all member modules into one batched
    TensorDict (leading dim = number of members) and evaluates every
    member with a single ``torch.vmap`` call instead of a Python loop.
    All members must share the same architecture.
    """

    def __init__(self, modules, **kwargs):
        super().__init__()
        # combine_state_for_ensemble causes graph breaks
        # Stack each member's parameters along a new leading dimension.
        self.params = from_modules(*modules, as_module=True)
        # Keep a parameter-less ("meta"-device) skeleton of the shared
        # architecture; real per-member parameters are bound in _call.
        with self.params[0].data.to("meta").to_module(modules[0]):
            self.module = deepcopy(modules[0])
        self._repr = str(modules[0])
        # Number of member modules, reported by __len__.
        self._n = len(modules)

    def __len__(self):
        return self._n

    def _call(self, params, *args, **kwargs):
        # Bind one member's parameter slice to the skeleton, then run it.
        with params.to_module(self.module):
            return self.module(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # vmap over the leading (member) dim of params; the inputs are
        # shared across members (in_dims=None). randomness="different"
        # gives each member an independent RNG stream (e.g. dropout).
        return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)

    def __repr__(self):
        return f"Vectorized {len(self)}x " + self._repr
# TODO (azouitine): I think in our case this function is not useful; we should remove it # TODO (azouitine): I think in our case this function is not useful; we should remove it
# after some investigation # after some investigation
# borrowed from tdmpc # borrowed from tdmpc
@ -711,46 +745,68 @@ if __name__ == "__main__":
config = SACConfig() config = SACConfig()
config.num_critics = 10 config.num_critics = 10
encoder = SACObservationEncoder(config) config.vision_encoder_name = None
actor_encoder = SACObservationEncoder(config) encoder = SACObservationEncoder(config, nn.Identity())
encoder = torch.compile(encoder) # actor_encoder = SACObservationEncoder(config)
# encoder = torch.compile(encoder)
critic_ensemble = CriticEnsemble( critic_ensemble = CriticEnsemble(
encoder=encoder, encoder=encoder,
network_list=nn.ModuleList( ensemble=Ensemble(
[ [
MLP( CriticHead(
input_dim=encoder.output_dim + config.output_shapes["action"][0], input_dim=encoder.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs, **config.critic_network_kwargs,
) )
for _ in range(config.num_critics) for _ in range(config.num_critics)
] ]
), ),
output_normalization=nn.Identity(),
) )
actor = Policy( # actor = Policy(
encoder=actor_encoder, # encoder=actor_encoder,
network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs), # network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0], # action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder, # encoder_is_shared=config.shared_encoder,
**config.policy_kwargs, # **config.policy_kwargs,
) # )
encoder = encoder.to("cuda:0") # encoder = encoder.to("cuda:0")
critic_ensemble = torch.compile(critic_ensemble) # critic_ensemble = torch.compile(critic_ensemble)
critic_ensemble = critic_ensemble.to("cuda:0") critic_ensemble = critic_ensemble.to("cuda:0")
actor = torch.compile(actor) # actor = torch.compile(actor)
actor = actor.to("cuda:0") # actor = actor.to("cuda:0")
obs_dict = { obs_dict = {
"observation.image": torch.randn(1, 3, 84, 84), "observation.image": torch.randn(8, 3, 84, 84),
"observation.state": torch.randn(1, 4), "observation.state": torch.randn(8, 4),
} }
actions = torch.randn(1, 2).to("cuda:0") actions = torch.randn(8, 2).to("cuda:0")
obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()} # obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
print("compiling...") # print("compiling...")
# q_value = critic_ensemble(obs_dict, actions) q_value = critic_ensemble(obs_dict, actions)
action = actor(obs_dict) print(q_value.size())
print("compiled") # action = actor(obs_dict)
# print("compiled")
# start = time.perf_counter()
# for _ in range(1000):
# # features = encoder(obs_dict)
# action = actor(obs_dict)
# # q_value = critic_ensemble(obs_dict, actions)
# print("Time taken:", time.perf_counter() - start)
# Compare the performance of the ensemble vs a for loop of 16 MLPs
ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
ensemble = ensemble.to("cuda:0")
critic = CriticHead(256, [256, 256])
critic = critic.to("cuda:0")
data_ensemble = torch.randn(8, 256).to("cuda:0")
ensemble = torch.compile(ensemble)
# critic = torch.compile(critic)
print(ensemble(data_ensemble).size())
print(critic(data_ensemble).size())
start = time.perf_counter() start = time.perf_counter()
for _ in range(1000): for _ in range(1000):
# features = encoder(obs_dict) ensemble(data_ensemble)
action = actor(obs_dict) print("Time taken:", time.perf_counter() - start)
# q_value = critic_ensemble(obs_dict, actions) start = time.perf_counter()
for _ in range(1000):
for i in range(2):
critic(data_ensemble)
print("Time taken:", time.perf_counter() - start) print("Time taken:", time.perf_counter() - start)

View File

@ -5,14 +5,14 @@ fps: 20
env: env:
name: maniskill/pushcube name: maniskill/pushcube
task: PushCube-v1 task: PushCube-v1
image_size: 128 image_size: 64
control_mode: pd_ee_delta_pose control_mode: pd_ee_delta_pose
state_dim: 25 state_dim: 25
action_dim: 7 action_dim: 7
fps: ${fps} fps: ${fps}
obs: rgb obs: rgb
render_mode: rgb_array render_mode: rgb_array
render_size: 128 render_size: 64
device: cuda device: cuda
reward_classifier: reward_classifier:

View File

@ -59,32 +59,36 @@ policy:
input_shapes: input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"] observation.state: ["${env.state_dim}"]
observation.image: [3, 128, 128] observation.image: [3, 64, 64]
observation.image.2: [3, 64, 64]
output_shapes: output_shapes:
action: [7] action: [7]
# Normalization / Unnormalization camera_number: 2
input_normalization_modes:
observation.state: min_max
input_normalization_params:
observation.state:
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400, # Normalization / Unnormalization
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163, input_normalization_modes: null
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135, # input_normalization_modes:
0.4001] # observation.state: min_max
input_normalization_params: null
# observation.state:
# min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
# 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
# -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
# -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
# max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
# 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
# 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
# 0.4001]
output_normalization_modes: output_normalization_modes:
action: min_max action: min_max
output_normalization_params: output_normalization_params:
action: action:
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0] min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
output_normalization_shapes: output_normalization_shapes:
action: [7] action: [7]
@ -94,8 +98,8 @@ policy:
# discount: 0.99 # discount: 0.99
discount: 0.80 discount: 0.80
temperature_init: 1.0 temperature_init: 1.0
num_critics: 2 #10 num_critics: 10 #10
num_subsample_critics: null num_subsample_critics: 2
critic_lr: 3e-4 critic_lr: 3e-4
actor_lr: 3e-4 actor_lr: 3e-4
temperature_lr: 3e-4 temperature_lr: 3e-4

View File

@ -10,7 +10,6 @@ from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]: def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation. """Convert environment observation to LeRobot format observation.
Args: Args:
@ -42,6 +41,7 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1) state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img return_observations["observation.image"] = img
return_observations["observation.image.2"] = img
return_observations["observation.state"] = state return_observations["observation.state"] = state
return return_observations return return_observations
@ -142,7 +142,7 @@ def make_maniskill(
env.unwrapped.metadata["render_fps"] = 20 env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env) env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env) env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0) env = ManiSkillMultiplyActionWrapper(env, multiply_factor=1)
return env return env