- Refactor observation encoder in `modeling_sac.py`

- Added `torch.compile` to the actor and learner servers.
- Organized imports in `train_sac.py`.
- Optimized the parameters push by not sending the frozen pre-trained encoder.
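The parameters-push optimization boils down to filtering the frozen encoder keys out of the actor state dict before serializing it. A minimal sketch of the idea, assuming the `policy.actor` / `policy.config` attributes used in the learner server; the function name and the use of `torch.save` for serialization are illustrative, not the actual server API:

```python
import io

import torch


def serialize_actor_parameters(policy) -> bytes:
    """Illustrative sketch: build the payload pushed from the learner to the actor.

    When the pretrained vision encoder is frozen its weights never change and the
    actor already holds a local copy, so they are filtered out of the state dict
    before serialization. The actor side then loads with strict=False.
    """
    params = policy.actor.state_dict()
    if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
        params = {k: v for k, v in params.items() if not k.startswith("encoder.")}
    params = {k: v.detach().cpu() for k, v in params.items()}

    buf = io.BytesIO()
    torch.save(params, buf)
    return buf.getvalue()
```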

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-01-31 16:45:52 +00:00
parent f1c8bfe01e
commit 506821c7df
6 changed files with 199 additions and 85 deletions

View File

@@ -55,9 +55,10 @@ class SACConfig:
     )
     camera_number: int = 1
     # Add type annotations for these fields:
-    vision_encoder_name: str = field(default="microsoft/resnet-18")
+    vision_encoder_name: str | None = field(default="microsoft/resnet-18")
+    freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
-    shared_encoder: bool = False
+    shared_encoder: bool = True
     discount: float = 0.99
     temperature_init: float = 1.0
     num_critics: int = 2
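For context, a rough sketch of how these options are meant to combine, using the encoder classes introduced in `modeling_sac.py` below. The import paths are assumptions, and the snippet mirrors rather than reproduces `SACObservationEncoder.__init__`:

```python
# Assumed import paths; only the class/function names are taken from this commit.
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import (
    DefaultImageEncoder,
    PretrainedImageEncoder,
    freeze_image_encoder,
)

config = SACConfig()
config.vision_encoder_name = "microsoft/resnet-18"  # None -> train DefaultImageEncoder from scratch
config.freeze_vision_encoder = True  # frozen weights are skipped by the optimizer and the push
config.shared_encoder = True  # actor and critics reuse one observation encoder

# Encoder selection mirroring SACObservationEncoder.__init__
if config.vision_encoder_name is not None:
    image_encoder = PretrainedImageEncoder(config)
else:
    image_encoder = DefaultImageEncoder(config)

if config.freeze_vision_encoder:
    freeze_image_encoder(image_encoder)
```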

View File

@@ -312,7 +312,7 @@ class CriticEnsemble(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        network_list: nn.Module,
+        network_list: nn.ModuleList,
         init_final: Optional[float] = None,
     ):
         super().__init__()
@@ -320,6 +320,12 @@ class CriticEnsemble(nn.Module):
         self.network_list = network_list
         self.init_final = init_final
 
+        self.parameters_to_optimize = []
+        # Handle the case where part of the encoder is frozen
+        if self.encoder is not None:
+            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
+        self.parameters_to_optimize += list(self.network_list.parameters())
+
         # Find the last Linear layer's output dimension
         for layer in reversed(network_list[0].net):
             if isinstance(layer, nn.Linear):
@@ -342,6 +348,7 @@ class CriticEnsemble(nn.Module):
             self.output_layers.append(output_layer)
         self.output_layers = nn.ModuleList(self.output_layers)
+        self.parameters_to_optimize += list(self.output_layers.parameters())
 
     def forward(
         self,
@@ -474,61 +481,25 @@ class SACObservationEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.has_pretrained_vision_encoder = False
+        self.parameters_to_optimize = []
+
+        self.aggregation_size: int = 0
         if "observation.image" in config.input_shapes:
             self.camera_number = config.camera_number
-            self.aggregation_size: int = 0
 
             if self.config.vision_encoder_name is not None:
+                self.image_enc_layers = PretrainedImageEncoder(config)
                 self.has_pretrained_vision_encoder = True
-                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
-                self.freeze_encoder()
-                self.image_enc_proj = nn.Sequential(
-                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-                    nn.LayerNorm(config.latent_dim),
-                    nn.Tanh(),
-                )
             else:
-                self.image_enc_layers = nn.Sequential(
-                    nn.Conv2d(
-                        in_channels=config.input_shapes["observation.image"][0],
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=7,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=5,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=3,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                    nn.Conv2d(
-                        in_channels=config.image_encoder_hidden_dim,
-                        out_channels=config.image_encoder_hidden_dim,
-                        kernel_size=3,
-                        stride=2,
-                    ),
-                    nn.ReLU(),
-                )
-                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
-                with torch.inference_mode():
-                    self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-                self.image_enc_layers.extend(
-                    nn.Sequential(
-                        nn.Flatten(),
-                        nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-                        nn.LayerNorm(config.latent_dim),
-                        nn.Tanh(),
-                    )
-                )
+                self.image_enc_layers = DefaultImageEncoder(config)
             self.aggregation_size += config.latent_dim * self.camera_number
+
+            if config.freeze_vision_encoder:
+                freeze_image_encoder(self.image_enc_layers)
+            else:
+                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+
         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -539,6 +510,8 @@ class SACObservationEncoder(nn.Module):
             )
             self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
+
         if "observation.environment_state" in config.input_shapes:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -548,26 +521,11 @@ class SACObservationEncoder(nn.Module):
                 nn.LayerNorm(normalized_shape=config.latent_dim),
                 nn.Tanh(),
             )
             self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
 
         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
+        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
-
-    def _load_pretrained_vision_encoder(self):
-        """Set up CNN encoder"""
-        from transformers import AutoModel
-
-        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
-        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
-            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
-        elif hasattr(self.image_enc_layers, "fc"):
-            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
-        else:
-            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
-        return self.image_enc_layers, self.image_enc_out_shape
-
-    def freeze_encoder(self):
-        """Freeze all parameters in the encoder"""
-        for param in self.image_enc_layers.parameters():
-            param.requires_grad = False
 
     def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
         """Encode the image and/or state vector.
@@ -579,12 +537,10 @@ class SACObservationEncoder(nn.Module):
             # Concatenate all images along the channel dimension.
             image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
             for image_key in image_keys:
-                if self.has_pretrained_vision_encoder:
-                    enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
-                    enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
-                else:
-                    enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
+                enc_feat = self.image_enc_layers(obs_dict[image_key])
+                # if not self.has_pretrained_vision_encoder:
+                #     enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
                 feat.append(enc_feat)
         if "observation.environment_state" in self.config.input_shapes:
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
@@ -602,10 +558,107 @@ class SACObservationEncoder(nn.Module):
         return self.config.latent_dim
 
 
+class DefaultImageEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.image_enc_layers = nn.Sequential(
+            nn.Conv2d(
+                in_channels=config.input_shapes["observation.image"][0],
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=7,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=5,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=3,
+                stride=2,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=config.image_encoder_hidden_dim,
+                out_channels=config.image_encoder_hidden_dim,
+                kernel_size=3,
+                stride=2,
+            ),
+            nn.ReLU(),
+        )
+        dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+        with torch.inference_mode():
+            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
+        self.image_enc_layers.extend(
+            nn.Sequential(
+                nn.Flatten(),
+                nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+                nn.LayerNorm(config.latent_dim),
+                nn.Tanh(),
+            )
+        )
+
+    def forward(self, x):
+        return self.image_enc_layers(x)
+
+
+class PretrainedImageEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
+        self.image_enc_proj = nn.Sequential(
+            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+            nn.LayerNorm(config.latent_dim),
+            nn.Tanh(),
+        )
+
+    def _load_pretrained_vision_encoder(self, config):
+        """Set up CNN encoder"""
+        from transformers import AutoModel
+
+        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name)
+        # self.image_enc_layers.pooler = Identity()
+        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
+            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
+        elif hasattr(self.image_enc_layers, "fc"):
+            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
+        else:
+            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
+        return self.image_enc_layers, self.image_enc_out_shape
+
+    def forward(self, x):
+        # TODO: (maractingi, azouitine) check that the forward pass of the pretrained model
+        # doesn't reach the classifier layer, because we don't need it
+        enc_feat = self.image_enc_layers(x).pooler_output
+        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
+        return enc_feat
+
+
+def freeze_image_encoder(image_encoder: nn.Module):
+    """Freeze all parameters in the encoder"""
+    for param in image_encoder.parameters():
+        param.requires_grad = False
+
+
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
 
 
+class Identity(nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
 # TODO (azouitine): I think in our case this function is not useful; we should remove it
 # after some investigation
 # borrowed from tdmpc
@@ -626,3 +679,54 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     inp = torch.flatten(image_tensor, end_dim=-4)
     flat_out = fn(inp)
     return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
+
+
+if __name__ == "__main__":
+    # Test the SACObservationEncoder
+    import time
+
+    config = SACConfig()
+    config.num_critics = 10
+    encoder = SACObservationEncoder(config)
+    actor_encoder = SACObservationEncoder(config)
+    encoder = torch.compile(encoder)
+    critic_ensemble = CriticEnsemble(
+        encoder=encoder,
+        network_list=nn.ModuleList(
+            [
+                MLP(
+                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs,
+                )
+                for _ in range(config.num_critics)
+            ]
+        ),
+    )
+    actor = Policy(
+        encoder=actor_encoder,
+        network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
+        action_dim=config.output_shapes["action"][0],
+        encoder_is_shared=config.shared_encoder,
+        **config.policy_kwargs,
+    )
+    encoder = encoder.to("cuda:0")
+    critic_ensemble = torch.compile(critic_ensemble)
+    critic_ensemble = critic_ensemble.to("cuda:0")
+    actor = torch.compile(actor)
+    actor = actor.to("cuda:0")
+    obs_dict = {
+        "observation.image": torch.randn(1, 3, 84, 84),
+        "observation.state": torch.randn(1, 4),
+    }
+    actions = torch.randn(1, 2).to("cuda:0")
+    obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
+    print("compiling...")
+    # q_value = critic_ensemble(obs_dict, actions)
+    action = actor(obs_dict)
+    print("compiled")
+    start = time.perf_counter()
+    for _ in range(1000):
+        # features = encoder(obs_dict)
+        action = actor(obs_dict)
+        # q_value = critic_ensemble(obs_dict, actions)
+    print("Time taken:", time.perf_counter() - start)

View File

@@ -52,6 +52,8 @@ policy:
   n_action_steps: 1
 
   shared_encoder: true
+  # vision_encoder_name: null
+  freeze_vision_encoder: false
   input_shapes:
     # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
     observation.state: ["${env.state_dim}"]

View File

@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
     #     pretrained_policy_name_or_path=None,
     #     device=device,
     # )
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     # HACK for maniskill
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
                 logging.debug("[ACTOR] Load new parameters from Learner.")
                 state_dict = parameters_queue.get()
                 state_dict = move_state_dict_to_device(state_dict, device=device)
-                policy.actor.load_state_dict(state_dict)
+                # strict=False handles the case where the image encoder is frozen and its weights
+                # are not sent over the network. Be careful: this can hide issues if the wrong keys are passed.
+                policy.actor.load_state_dict(state_dict, strict=False)
 
             if len(list_transition_to_send_to_learner) > 0:
                 logging.debug(

View File

@@ -259,6 +259,9 @@ def learner_push_parameters(
     while True:
         with policy_lock:
             params_dict = policy.actor.state_dict()
+            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
+                params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
+
         params_dict = move_state_dict_to_device(params_dict, device="cpu")
         # Serialize
         buf = io.BytesIO()
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         dataset_stats=None,
         pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
     )
+    # compile policy
+    policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)

View File

@@ -13,26 +13,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import functools
-from pprint import pformat
+import logging
 import random
-from typing import Optional, Sequence, TypedDict, Callable
+from pprint import pformat
+from typing import Callable, Optional, Sequence, TypedDict
 
 import hydra
 import torch
 import torch.nn.functional as F
-from torch import nn
-from tqdm import tqdm
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from tqdm import tqdm
 
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-
 # TODO: Remove the import of maniskill
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.envs.factory import make_env, make_maniskill_env
-from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
+from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy