- Refactor the observation encoder in `modeling_sac.py`

- Add `torch.compile` to the actor and learner servers.
- Organize imports in `train_sac.py`.
- Optimize the parameters push by not sending the frozen pre-trained encoder (see the sketch below).
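In short, the learner pushes only the actor weights that are actually trained, and the actor reloads them with `strict=False` so the missing frozen-encoder keys simply keep their current values; `torch.compile(policy)` is applied on both servers. Below is a minimal, self-contained sketch of that push/load flow; `TinyActor` and `filter_frozen_encoder` are illustrative stand-ins, not code from this repository.

    from torch import nn


    def filter_frozen_encoder(state_dict: dict, prefix: str = "encoder.") -> dict:
        """Drop the frozen pre-trained encoder weights before pushing parameters."""
        return {k: v for k, v in state_dict.items() if not k.startswith(prefix)}


    class TinyActor(nn.Module):
        """Stand-in for the SAC actor: a frozen encoder plus a trainable head."""

        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(8, 8)  # plays the role of the frozen pre-trained vision encoder
            self.head = nn.Linear(8, 2)


    learner_actor, remote_actor = TinyActor(), TinyActor()

    # Learner side: serialize only the trainable weights (the "encoder." keys are filtered out).
    pushed = filter_frozen_encoder(learner_actor.state_dict())

    # Actor side: strict=False ignores the missing encoder keys, which keep their frozen values.
    missing, unexpected = remote_actor.load_state_dict(pushed, strict=False)
    assert all(k.startswith("encoder.") for k in missing) and not unexpected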

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi, 2025-01-31 16:45:52 +00:00; committed by Adil Zouitine
commit d1cc9665da (parent e35f8ed8a8)
6 changed files with 199 additions and 85 deletions

View File

@@ -55,9 +55,10 @@ class SACConfig:
     )
     camera_number: int = 1
     # Add type annotations for these fields:
-    vision_encoder_name: str = field(default="microsoft/resnet-18")
+    vision_encoder_name: str | None = field(default="microsoft/resnet-18")
+    freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
-    shared_encoder: bool = False
+    shared_encoder: bool = True
     discount: float = 0.99
     temperature_init: float = 1.0
     num_critics: int = 2

View File

@ -312,7 +312,7 @@ class CriticEnsemble(nn.Module):
def __init__( def __init__(
self, self,
encoder: Optional[nn.Module], encoder: Optional[nn.Module],
network_list: nn.Module, network_list: nn.ModuleList,
init_final: Optional[float] = None, init_final: Optional[float] = None,
): ):
super().__init__() super().__init__()
@@ -320,6 +320,12 @@ class CriticEnsemble(nn.Module):
         self.network_list = network_list
         self.init_final = init_final

+        self.parameters_to_optimize = []
+        # Handle the case where a part of the encoder is frozen
+        if self.encoder is not None:
+            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
+        self.parameters_to_optimize += list(self.network_list.parameters())
+
         # Find the last Linear layer's output dimension
         for layer in reversed(network_list[0].net):
             if isinstance(layer, nn.Linear):
@@ -342,6 +348,7 @@ class CriticEnsemble(nn.Module):
             self.output_layers.append(output_layer)
         self.output_layers = nn.ModuleList(self.output_layers)
+        self.parameters_to_optimize += list(self.output_layers.parameters())

     def forward(
         self,
@@ -474,19 +481,86 @@ class SACObservationEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.has_pretrained_vision_encoder = False
+        self.parameters_to_optimize = []
+
+        self.aggregation_size: int = 0
         if "observation.image" in config.input_shapes:
             self.camera_number = config.camera_number
-            self.aggregation_size: int = 0
+
             if self.config.vision_encoder_name is not None:
+                self.image_enc_layers = PretrainedImageEncoder(config)
                 self.has_pretrained_vision_encoder = True
-                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
-                self.freeze_encoder()
-                self.image_enc_proj = nn.Sequential(
-                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-                    nn.LayerNorm(config.latent_dim),
-                    nn.Tanh(),
-                )
             else:
+                self.image_enc_layers = DefaultImageEncoder(config)
+
+            self.aggregation_size += config.latent_dim * self.camera_number
+
+            if config.freeze_vision_encoder:
+                freeze_image_encoder(self.image_enc_layers)
+            else:
+                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+
+        if "observation.state" in config.input_shapes:
+            self.state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
+
+        if "observation.environment_state" in config.input_shapes:
+            self.env_state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_shapes["observation.environment_state"][0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
+
+        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
+        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
+
+    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
+        """Encode the image and/or state vector.
+
+        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
+        over all features.
+        """
+        feat = []
+        # Concatenate all images along the channel dimension.
+        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
+        for image_key in image_keys:
+            enc_feat = self.image_enc_layers(obs_dict[image_key])
+            # if not self.has_pretrained_vision_encoder:
+            #     enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
+            feat.append(enc_feat)
+        if "observation.environment_state" in self.config.input_shapes:
+            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
+        if "observation.state" in self.config.input_shapes:
+            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
+
+        features = torch.cat(tensors=feat, dim=-1)
+        features = self.aggregation_layer(features)
+
+        return features
+
+    @property
+    def output_dim(self) -> int:
+        """Returns the dimension of the encoder output"""
+        return self.config.latent_dim
+
+
+class DefaultImageEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
         self.image_enc_layers = nn.Sequential(
             nn.Conv2d(
                 in_channels=config.input_shapes["observation.image"][0],
@@ -528,34 +602,29 @@ class SACObservationEncoder(nn.Module):
                     nn.Tanh(),
                 )
             )
-            self.aggregation_size += config.latent_dim * self.camera_number
-        if "observation.state" in config.input_shapes:
-            self.state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-        if "observation.environment_state" in config.input_shapes:
-            self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_shapes["observation.environment_state"][0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)

-    def _load_pretrained_vision_encoder(self):
+    def forward(self, x):
+        return self.image_enc_layers(x)
+
+
+class PretrainedImageEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
+        self.image_enc_proj = nn.Sequential(
+            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+            nn.LayerNorm(config.latent_dim),
+            nn.Tanh(),
+        )
+
+    def _load_pretrained_vision_encoder(self, config):
         """Set up CNN encoder"""
         from transformers import AutoModel

-        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
+        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name)
+        # self.image_enc_layers.pooler = Identity()
+
         if hasattr(self.image_enc_layers.config, "hidden_sizes"):
             self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
         elif hasattr(self.image_enc_layers, "fc"):
@@ -564,48 +633,32 @@ class SACObservationEncoder(nn.Module):
             raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
         return self.image_enc_layers, self.image_enc_out_shape

-    def freeze_encoder(self):
-        """Freeze all parameters in the encoder"""
-        for param in self.image_enc_layers.parameters():
-            param.requires_grad = False
-
-    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
-        """Encode the image and/or state vector.
-
-        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
-        over all features.
-        """
-        feat = []
-        # Concatenate all images along the channel dimension.
-        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
-        for image_key in image_keys:
-            if self.has_pretrained_vision_encoder:
-                enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
-                enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
-            else:
-                enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
-            feat.append(enc_feat)
-        if "observation.environment_state" in self.config.input_shapes:
-            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
-        if "observation.state" in self.config.input_shapes:
-            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
-
-        features = torch.cat(tensors=feat, dim=-1)
-        features = self.aggregation_layer(features)
-
-        return features
-
-    @property
-    def output_dim(self) -> int:
-        """Returns the dimension of the encoder output"""
-        return self.config.latent_dim
+    def forward(self, x):
+        # TODO: (maractingi, azouitine) check that the forward pass of the pretrained model
+        # does not reach the classifier layer, because we don't need it
+        enc_feat = self.image_enc_layers(x).pooler_output
+        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+        return enc_feat
+
+
+def freeze_image_encoder(image_encoder: nn.Module):
+    """Freeze all parameters in the encoder"""
+    for param in image_encoder.parameters():
+        param.requires_grad = False


 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


+class Identity(nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
 # TODO (azouitine): I think in our case this function is not useful; we should remove it
 # after some investigation
 # borrowed from tdmpc
@@ -626,3 +679,54 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
+
+
+if __name__ == "__main__":
+    # Test the SACObservationEncoder
+    import time
+
+    config = SACConfig()
+    config.num_critics = 10
+    encoder = SACObservationEncoder(config)
+    actor_encoder = SACObservationEncoder(config)
+    encoder = torch.compile(encoder)
+    critic_ensemble = CriticEnsemble(
+        encoder=encoder,
+        network_list=nn.ModuleList(
+            [
+                MLP(
+                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs,
+                )
+                for _ in range(config.num_critics)
+            ]
+        ),
+    )
+    actor = Policy(
+        encoder=actor_encoder,
+        network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
+        action_dim=config.output_shapes["action"][0],
+        encoder_is_shared=config.shared_encoder,
+        **config.policy_kwargs,
+    )
+    encoder = encoder.to("cuda:0")
+    critic_ensemble = torch.compile(critic_ensemble)
+    critic_ensemble = critic_ensemble.to("cuda:0")
+    actor = torch.compile(actor)
+    actor = actor.to("cuda:0")
+    obs_dict = {
+        "observation.image": torch.randn(1, 3, 84, 84),
+        "observation.state": torch.randn(1, 4),
+    }
+    actions = torch.randn(1, 2).to("cuda:0")
+    obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
+    print("compiling...")
+    # q_value = critic_ensemble(obs_dict, actions)
+    action = actor(obs_dict)
+    print("compiled")
+    start = time.perf_counter()
+    for _ in range(1000):
+        # features = encoder(obs_dict)
+        action = actor(obs_dict)
+        # q_value = critic_ensemble(obs_dict, actions)
+    print("Time taken:", time.perf_counter() - start)

View File

@@ -52,6 +52,8 @@ policy:
   n_action_steps: 1

   shared_encoder: true
+  # vision_encoder_name: null
+  freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]

View File

@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
     #     pretrained_policy_name_or_path=None,
     #     device=device,
     # )
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)

     # HACK for maniskill
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
                 logging.debug("[ACTOR] Load new parameters from Learner.")
                 state_dict = parameters_queue.get()
                 state_dict = move_state_dict_to_device(state_dict, device=device)
-                policy.actor.load_state_dict(state_dict)
+                # strict=False for the case when the image encoder is frozen and not sent through
+                # the network. Be careful: this might cause issues if the wrong keys are passed.
+                policy.actor.load_state_dict(state_dict, strict=False)
                 if len(list_transition_to_send_to_learner) > 0:
                     logging.debug(

View File

@@ -259,6 +259,9 @@ def learner_push_parameters(
     while True:
         with policy_lock:
             params_dict = policy.actor.state_dict()
+            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+                params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
+
         params_dict = move_state_dict_to_device(params_dict, device="cpu")
         # Serialize
         buf = io.BytesIO()
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         dataset_stats=None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
+    # compile policy
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)

     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)

View File

@@ -13,26 +13,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import functools
-from pprint import pformat
+import logging
 import random
-from typing import Optional, Sequence, TypedDict, Callable
+from pprint import pformat
+from typing import Callable, Optional, Sequence, TypedDict

 import hydra
 import torch
 import torch.nn.functional as F
-from torch import nn
-from tqdm import tqdm
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
-
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from torch import nn
+from tqdm import tqdm

 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.envs.factory import make_env, make_maniskill_env
-from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
+from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy