Refactor SAC policy with performance optimizations and multi-camera support
- Introduced Ensemble and CriticHead classes for more efficient critic network handling (see the sketch below)
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks

Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
This commit is contained in:
parent ff47c0b0d3
commit ff82367c62
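The core optimization here is to stop looping over critic MLPs in Python and instead evaluate all of them in a single vmapped forward pass over stacked parameters. The diff below implements this with `tensordict.from_modules` (the new `Ensemble` class); the same idea can be sketched with plain `torch.func`, as below. The head architecture, critic count, and batch size in this sketch are illustrative, not taken from the config.

```python
# Sketch of the vectorized-critic idea using torch.func (the commit itself
# uses tensordict.from_modules; see the Ensemble class further down).
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state

num_critics, batch, dim = 10, 8, 32
heads = [
    nn.Sequential(nn.Linear(dim, 256), nn.SiLU(), nn.Linear(256, 1))
    for _ in range(num_critics)
]

# Stack all per-critic parameters/buffers along a new leading dimension.
params, buffers = stack_module_state(heads)

# A stateless "meta" copy of the architecture to run with externally supplied params.
base = copy.deepcopy(heads[0]).to("meta")


def call_one(p, b, x):
    return functional_call(base, (p, b), (x,))


x = torch.randn(batch, dim)
# One vmapped forward evaluates every critic at once.
q_values = torch.vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)
print(q_values.squeeze(-1).shape)  # torch.Size([10, 8])
```

A single batched call like this also plays better with `torch.compile`, which is why the commit compiles the critic ensembles right after constructing them.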
@@ -17,10 +17,12 @@
 # TODO: (1) better device management
 
+from copy import deepcopy
 from typing import Callable, Optional, Tuple
 
 import einops
 import numpy as np
+from tensordict import from_modules
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
@@ -85,9 +87,9 @@ class SACPolicy(
         self.critic_ensemble = CriticEnsemble(
             encoder=encoder_critic,
-            network_list=nn.ModuleList(
+            ensemble=Ensemble(
                 [
-                    MLP(
+                    CriticHead(
                         input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                         **config.critic_network_kwargs,
                     )
@@ -99,9 +101,9 @@ class SACPolicy(
         self.critic_target = CriticEnsemble(
             encoder=encoder_critic,
-            network_list=nn.ModuleList(
+            ensemble=Ensemble(
                 [
-                    MLP(
+                    CriticHead(
                         input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                         **config.critic_network_kwargs,
                     )
@@ -113,6 +115,9 @@ class SACPolicy(
         self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
 
+        self.critic_ensemble = torch.compile(self.critic_ensemble)
+        self.critic_target = torch.compile(self.critic_target)
+
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
@@ -274,6 +279,35 @@ class MLP(nn.Module):
         return self.net(x)
 
 
+class CriticHead(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dims: list[int],
+        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
+        activate_final: bool = False,
+        dropout_rate: Optional[float] = None,
+        init_final: Optional[float] = None,
+    ):
+        super().__init__()
+        self.net = MLP(
+            input_dim=input_dim,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            activate_final=activate_final,
+            dropout_rate=dropout_rate,
+        )
+        self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
+        if init_final is not None:
+            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
+            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
+        else:
+            orthogonal_init()(self.output_layer.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.output_layer(self.net(x))
+
+
 class CriticEnsemble(nn.Module):
     """
     ┌──────────────────┬─────────────────────────────────────────────────────────┐
@@ -316,13 +350,13 @@ class CriticEnsemble(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        network_list: nn.ModuleList,
+        ensemble: "Ensemble[CriticHead]",
         output_normalization: nn.Module,
         init_final: Optional[float] = None,
     ):
         super().__init__()
         self.encoder = encoder
-        self.network_list = network_list
+        self.ensemble = ensemble
         self.init_final = init_final
         self.output_normalization = output_normalization
@@ -330,31 +364,7 @@ class CriticEnsemble(nn.Module):
         # Handle the case where a part of the encoder if frozen
         if self.encoder is not None:
             self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
 
-        self.parameters_to_optimize += list(self.network_list.parameters())
-        # Find the last Linear layer's output dimension
-        for layer in reversed(network_list[0].net):
-            if isinstance(layer, nn.Linear):
-                out_features = layer.out_features
-                break
-
-        # Output layer
-        self.output_layers = []
-        if init_final is not None:
-            for _ in network_list:
-                output_layer = nn.Linear(out_features, 1)
-                nn.init.uniform_(output_layer.weight, -init_final, init_final)
-                nn.init.uniform_(output_layer.bias, -init_final, init_final)
-                self.output_layers.append(output_layer)
-        else:
-            self.output_layers = []
-            for _ in network_list:
-                output_layer = nn.Linear(out_features, 1)
-                orthogonal_init()(output_layer.weight)
-                self.output_layers.append(output_layer)
-
-        self.output_layers = nn.ModuleList(self.output_layers)
-        self.parameters_to_optimize += list(self.output_layers.parameters())
+        self.parameters_to_optimize += list(self.ensemble.parameters())
 
     def forward(
         self,
@@ -373,12 +383,8 @@ class CriticEnsemble(nn.Module):
         obs_enc = observations if self.encoder is None else self.encoder(observations)
 
         inputs = torch.cat([obs_enc, actions], dim=-1)
-        list_q_values = []
-        for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
-            x = network(inputs)
-            value = output_layer(x)
-            list_q_values.append(value.squeeze(-1))
-        return torch.stack(list_q_values)
+        q_values = self.ensemble(inputs)  # [num_critics, B, 1]
+        return q_values.squeeze(-1)  # [num_critics, B]
 
 
 class Policy(nn.Module):
@@ -510,6 +516,7 @@ class SACObservationEncoder(nn.Module):
                 freeze_image_encoder(self.image_enc_layers)
             else:
                 self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+            self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
 
         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
@@ -546,14 +553,13 @@ class SACObservationEncoder(nn.Module):
         """
         feat = []
         obs_dict = self.input_normalization(obs_dict)
-        # Concatenate all images along the channel dimension.
-        image_keys = [k for k in obs_dict if k.startswith("observation.image")]
-        for image_key in image_keys:
-            enc_feat = self.image_enc_layers(obs_dict[image_key])
+        # Batch all images along the batch dimension, then encode them.
+        if len(self.all_image_keys) > 0:
+            images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
+            images_batched = self.image_enc_layers(images_batched)
+            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
+            feat.extend(embeddings_chunks)
 
-            # if not self.has_pretrained_vision_encoder:
-            #     enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
-            feat.append(enc_feat)
         if "observation.environment_state" in self.config.input_shapes:
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
         if "observation.state" in self.config.input_shapes:
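The hunk above replaces the per-camera encoder loop with one batched pass: frames from every camera key are concatenated along the batch dimension, encoded once, and split back into per-camera embeddings. A minimal standalone sketch of that pattern follows; the toy CNN and the 64x64 resolution are placeholders, not the repo's encoder.

```python
# Sketch of the batched multi-camera encoding pattern used above.
import torch
import torch.nn as nn

image_keys = ["observation.image", "observation.image.2"]  # two camera streams
encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

obs = {k: torch.randn(8, 3, 64, 64) for k in image_keys}  # batch of 8 per camera

# Instead of calling the encoder once per camera, stack the cameras along the
# batch dimension, encode everything in one forward pass, then split back.
images_batched = torch.cat([obs[k] for k in image_keys], dim=0)  # [num_cams * 8, 3, 64, 64]
embeddings = encoder(images_batched)                             # [num_cams * 8, 16]
per_camera = torch.chunk(embeddings, chunks=len(image_keys), dim=0)

feat = list(per_camera)  # one [8, 16] embedding per camera, ready to concatenate
print([f.shape for f in feat])
```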
@@ -671,6 +677,34 @@ class Identity(nn.Module):
         return x
 
 
+class Ensemble(nn.Module):
+    """
+    Vectorized ensemble of modules.
+    """
+
+    def __init__(self, modules, **kwargs):
+        super().__init__()
+        # combine_state_for_ensemble causes graph breaks
+        self.params = from_modules(*modules, as_module=True)
+        with self.params[0].data.to("meta").to_module(modules[0]):
+            self.module = deepcopy(modules[0])
+        self._repr = str(modules[0])
+        self._n = len(modules)
+
+    def __len__(self):
+        return self._n
+
+    def _call(self, params, *args, **kwargs):
+        with params.to_module(self.module):
+            return self.module(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
+
+    def __repr__(self):
+        return f"Vectorized {len(self)}x " + self._repr
+
+
 # TODO (azouitine): I think in our case this function is not usefull we should remove it
 # after some investigation
 # borrowed from tdmpc
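For reference, the vmapped Ensemble should be numerically equivalent to looping over its member heads; only the speed differs. A small hedged check is sketched below. The import path is an assumption about where the classes from this diff live, and dropout is off by default, so the outputs are deterministic.

```python
# Equivalence check: the vmapped Ensemble should match a plain Python loop
# over its member heads (the import path below is an assumption).
import torch

from lerobot.common.policies.sac.modeling_sac import CriticHead, Ensemble

heads = [CriticHead(input_dim=32, hidden_dims=[64, 64]) for _ in range(4)]
ensemble = Ensemble(heads)

x = torch.randn(8, 32)
with torch.no_grad():
    vmapped = ensemble(x)                              # [4, 8, 1]
    looped = torch.stack([head(x) for head in heads])  # [4, 8, 1]

print(torch.allclose(vmapped, looped, atol=1e-6))      # expected: True
```

The updated `__main__` block in the next hunk makes a similar ensemble-vs-loop comparison, but for wall-clock time rather than values.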
@@ -711,46 +745,68 @@ if __name__ == "__main__":
 
     config = SACConfig()
     config.num_critics = 10
-    encoder = SACObservationEncoder(config)
-    actor_encoder = SACObservationEncoder(config)
-    encoder = torch.compile(encoder)
+    config.vision_encoder_name = None
+    encoder = SACObservationEncoder(config, nn.Identity())
+    # actor_encoder = SACObservationEncoder(config)
+    # encoder = torch.compile(encoder)
     critic_ensemble = CriticEnsemble(
         encoder=encoder,
-        network_list=nn.ModuleList(
+        ensemble=Ensemble(
             [
-                MLP(
+                CriticHead(
                     input_dim=encoder.output_dim + config.output_shapes["action"][0],
                     **config.critic_network_kwargs,
                 )
                 for _ in range(config.num_critics)
             ]
         ),
         output_normalization=nn.Identity(),
     )
-    actor = Policy(
-        encoder=actor_encoder,
-        network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
-        action_dim=config.output_shapes["action"][0],
-        encoder_is_shared=config.shared_encoder,
-        **config.policy_kwargs,
-    )
-    encoder = encoder.to("cuda:0")
-    critic_ensemble = torch.compile(critic_ensemble)
+    # actor = Policy(
+    #     encoder=actor_encoder,
+    #     network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
+    #     action_dim=config.output_shapes["action"][0],
+    #     encoder_is_shared=config.shared_encoder,
+    #     **config.policy_kwargs,
+    # )
+    # encoder = encoder.to("cuda:0")
+    # critic_ensemble = torch.compile(critic_ensemble)
     critic_ensemble = critic_ensemble.to("cuda:0")
-    actor = torch.compile(actor)
-    actor = actor.to("cuda:0")
+    # actor = torch.compile(actor)
+    # actor = actor.to("cuda:0")
     obs_dict = {
-        "observation.image": torch.randn(1, 3, 84, 84),
-        "observation.state": torch.randn(1, 4),
+        "observation.image": torch.randn(8, 3, 84, 84),
+        "observation.state": torch.randn(8, 4),
     }
-    actions = torch.randn(1, 2).to("cuda:0")
-    obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
-    print("compiling...")
-    # q_value = critic_ensemble(obs_dict, actions)
-    action = actor(obs_dict)
-    print("compiled")
+    actions = torch.randn(8, 2).to("cuda:0")
+    # obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
+    # print("compiling...")
+    q_value = critic_ensemble(obs_dict, actions)
+    print(q_value.size())
+    # action = actor(obs_dict)
+    # print("compiled")
+    # start = time.perf_counter()
+    # for _ in range(1000):
+    #     # features = encoder(obs_dict)
+    #     action = actor(obs_dict)
+    #     # q_value = critic_ensemble(obs_dict, actions)
+    # print("Time taken:", time.perf_counter() - start)
+    # Compare the performance of the ensemble vs a for loop of 16 MLPs
+    ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
+    ensemble = ensemble.to("cuda:0")
+    critic = CriticHead(256, [256, 256])
+    critic = critic.to("cuda:0")
+    data_ensemble = torch.randn(8, 256).to("cuda:0")
+    ensemble = torch.compile(ensemble)
+    # critic = torch.compile(critic)
+    print(ensemble(data_ensemble).size())
+    print(critic(data_ensemble).size())
     start = time.perf_counter()
     for _ in range(1000):
-        # features = encoder(obs_dict)
-        action = actor(obs_dict)
-        # q_value = critic_ensemble(obs_dict, actions)
+        ensemble(data_ensemble)
     print("Time taken:", time.perf_counter() - start)
+    start = time.perf_counter()
+    for _ in range(1000):
+        for i in range(2):
+            critic(data_ensemble)
+    print("Time taken:", time.perf_counter() - start)
@@ -5,14 +5,14 @@ fps: 20
 env:
   name: maniskill/pushcube
   task: PushCube-v1
-  image_size: 128
+  image_size: 64
   control_mode: pd_ee_delta_pose
   state_dim: 25
   action_dim: 7
   fps: ${fps}
   obs: rgb
   render_mode: rgb_array
-  render_size: 128
+  render_size: 64
   device: cuda
 
 reward_classifier:
@@ -59,32 +59,36 @@ policy:
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]
-    observation.image: [3, 128, 128]
+    observation.image: [3, 64, 64]
+    observation.image.2: [3, 64, 64]
   output_shapes:
     action: [7]
 
+  camera_number: 2
+
   # Normalization / Unnormalization
-  input_normalization_modes:
-    observation.state: min_max
-  input_normalization_params:
-    observation.state:
-      min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
-            1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
-            -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
-            -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
-            8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
+  input_normalization_modes: null
+  # input_normalization_modes:
+  #   observation.state: min_max
+  input_normalization_params: null
+  # observation.state:
+  #   min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
+  #         1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
+  #         -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
+  #         -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
+  #         8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
 
-      max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
-             0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
-             7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
-             0.4001]
+  # max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
+  #        0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
+  #        7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
+  #        0.4001]
 
   output_normalization_modes:
     action: min_max
   output_normalization_params:
     action:
-      min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
-      max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
+      min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
+      max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
   output_normalization_shapes:
     action: [7]
@@ -94,8 +98,8 @@ policy:
   # discount: 0.99
   discount: 0.80
   temperature_init: 1.0
-  num_critics: 2 #10
-  num_subsample_critics: null
+  num_critics: 10 #10
+  num_subsample_critics: 2
   critic_lr: 3e-4
   actor_lr: 3e-4
   temperature_lr: 3e-4
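The critic count goes from 2 to 10 while `num_subsample_critics` goes from null to 2. How the subsampling is consumed is not part of this diff; assuming it follows the usual REDQ-style recipe (take the min over a random subset of critics when forming the TD target), the idea looks roughly like the sketch below, with shapes matching the `[num_critics, B]` output of `CriticEnsemble`.

```python
# Hypothetical REDQ-style target computation with critic subsampling.
# Assumes q_targets has shape [num_critics, B], as returned by CriticEnsemble;
# how the repo actually consumes num_subsample_critics is not shown in this diff.
import torch

num_critics = 10
num_subsample_critics = 2
batch_size = 256

q_targets = torch.randn(num_critics, batch_size)  # stand-in for critic_target(obs, next_actions)

# Randomly pick a subset of critic heads, then take the elementwise min over it.
idx = torch.randperm(num_critics)[:num_subsample_critics]
min_q = q_targets[idx].min(dim=0).values  # [B]

print(min_q.shape)
```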
@@ -10,7 +10,6 @@ from typing import Any
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
 
 
-
 def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
     """Convert environment observation to LeRobot format observation.
     Args:
@@ -42,6 +41,7 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
     state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
 
     return_observations["observation.image"] = img
+    return_observations["observation.image.2"] = img
     return_observations["observation.state"] = state
     return return_observations
@@ -142,7 +142,7 @@ def make_maniskill(
     env.unwrapped.metadata["render_fps"] = 20
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
-    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)
+    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=1)
 
     return env
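Dropping `multiply_factor` from 10.0 to 1 lines up with the config change above that narrows the action normalization range from ±10 to ±1: the policy now emits actions already on the controller's scale, so the wrapper no longer rescales them. `ManiSkillMultiplyActionWrapper` itself is not shown in this diff; a generic gymnasium action wrapper doing this kind of rescaling might look like the following hypothetical sketch, for illustration only.

```python
# Hypothetical sketch of a multiply-action wrapper (the real
# ManiSkillMultiplyActionWrapper is not shown in this diff).
import gymnasium as gym
import numpy as np


class MultiplyActionWrapper(gym.ActionWrapper):
    """Scale incoming actions by a constant factor before passing them to the env."""

    def __init__(self, env: gym.Env, multiply_factor: float = 1.0):
        super().__init__(env)
        self.multiply_factor = multiply_factor

    def action(self, action: np.ndarray) -> np.ndarray:
        return action * self.multiply_factor


# With multiply_factor=1 the wrapper is effectively a no-op: a policy output
# already normalized to [-1, 1] reaches the controller unchanged.
env = MultiplyActionWrapper(gym.make("Pendulum-v1"), multiply_factor=1.0)
```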