- Refactor the observation encoder in `modeling_sac.py`.

- Add `torch.compile` to the actor and learner servers.
- Organize imports in `train_sac.py`.
- Optimize the parameter push by not sending the frozen pre-trained encoder (a minimal illustrative sketch follows the commit metadata below).

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-01-31 16:45:52 +00:00
parent f1c8bfe01e
commit 506821c7df
6 changed files with 199 additions and 85 deletions
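To make the parameter-push optimization concrete, here is a minimal, hypothetical sketch of the idea: frozen `encoder.*` weights are dropped from the actor's state dict before serialization, and the receiving side loads with `strict=False` so the missing keys are tolerated. The helper name, the toy module, and the buffer round-trip below are illustrative only and are not the commit's actual code.

```python
# Hypothetical sketch of the "don't push the frozen encoder" optimization.
# Names here (strip_frozen_encoder, TinyActor) are illustrative, not the repo's code.
import io

import torch
from torch import nn


def strip_frozen_encoder(state_dict: dict) -> dict:
    # The frozen pre-trained vision encoder never changes on the learner,
    # so its "encoder." keys can be dropped before every parameters push.
    return {k: v for k, v in state_dict.items() if not k.startswith("encoder.")}


class TinyActor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)  # stand-in for the frozen pre-trained encoder
        self.head = nn.Linear(16, 4)     # trainable policy head

    def forward(self, x):
        return self.head(self.encoder(x))


learner_actor = TinyActor()
pushed = strip_frozen_encoder(learner_actor.state_dict())

# Serialize as the learner would before sending the parameters over the wire.
buf = io.BytesIO()
torch.save(pushed, buf)

# The actor side loads with strict=False so the missing "encoder.*" keys are tolerated.
remote_actor = TinyActor()
received = torch.load(io.BytesIO(buf.getvalue()))
missing, unexpected = remote_actor.load_state_dict(received, strict=False)
assert all(k.startswith("encoder.") for k in missing) and not unexpected
```

The corresponding real changes are in the `learner_push_parameters` and `act_with_policy` hunks further down.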

View File

@@ -55,9 +55,10 @@ class SACConfig:
)
camera_number: int = 1
# Add type annotations for these fields:
vision_encoder_name: str = field(default="microsoft/resnet-18")
vision_encoder_name: str | None = field(default="microsoft/resnet-18")
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = False
shared_encoder: bool = True
discount: float = 0.99
temperature_init: float = 1.0
num_critics: int = 2

View File

@@ -312,7 +312,7 @@ class CriticEnsemble(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network_list: nn.Module,
network_list: nn.ModuleList,
init_final: Optional[float] = None,
):
super().__init__()
@@ -320,6 +320,12 @@ class CriticEnsemble(nn.Module):
self.network_list = network_list
self.init_final = init_final
self.parameters_to_optimize = []
# Handle the case where a part of the encoder is frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.network_list.parameters())
# Find the last Linear layer's output dimension
for layer in reversed(network_list[0].net):
if isinstance(layer, nn.Linear):
@@ -342,6 +348,7 @@ class CriticEnsemble(nn.Module):
self.output_layers.append(output_layer)
self.output_layers = nn.ModuleList(self.output_layers)
self.parameters_to_optimize += list(self.output_layers.parameters())
def forward(
self,
@@ -474,61 +481,25 @@ class SACObservationEncoder(nn.Module):
super().__init__()
self.config = config
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if "observation.image" in config.input_shapes:
self.camera_number = config.camera_number
self.aggregation_size: int = 0
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
self.freeze_encoder()
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
else:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(
@@ -539,6 +510,8 @@ class SACObservationEncoder(nn.Module):
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
@@ -548,26 +521,11 @@ class SACObservationEncoder(nn.Module):
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
def _load_pretrained_vision_encoder(self):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
return self.image_enc_layers, self.image_enc_out_shape
def freeze_encoder(self):
"""Freeze all parameters in the encoder"""
for param in self.image_enc_layers.parameters():
param.requires_grad = False
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
@@ -579,12 +537,10 @@ class SACObservationEncoder(nn.Module):
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
for image_key in image_keys:
if self.has_pretrained_vision_encoder:
enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
else:
enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
enc_feat = self.image_enc_layers(obs_dict[image_key])
# if not self.has_pretrained_vision_encoder:
# enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
feat.append(enc_feat)
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
@@ -602,10 +558,107 @@ class SACObservationEncoder(nn.Module):
return self.config.latent_dim
class DefaultImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
def forward(self, x):
return self.image_enc_layers(x)
class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
def _load_pretrained_vision_encoder(self, config):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
# doesn't reach the classifier layer because we don't need it
enc_feat = self.image_enc_layers(x).pooler_output
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
return enc_feat
def freeze_image_encoder(image_encoder: nn.Module):
"""Freeze all parameters in the encoder"""
for param in image_encoder.parameters():
param.requires_grad = False
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
class Identity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
# TODO (azouitine): I think in our case this function is not useful; we should remove it
# after some investigation
# borrowed from tdmpc
@@ -626,3 +679,54 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
if __name__ == "__main__":
# Test the SACObservationEncoder
import time
config = SACConfig()
config.num_critics = 10
encoder = SACObservationEncoder(config)
actor_encoder = SACObservationEncoder(config)
encoder = torch.compile(encoder)
critic_ensemble = CriticEnsemble(
encoder=encoder,
network_list=nn.ModuleList(
[
MLP(
input_dim=encoder.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
),
)
actor = Policy(
encoder=actor_encoder,
network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
encoder = encoder.to("cuda:0")
critic_ensemble = torch.compile(critic_ensemble)
critic_ensemble = critic_ensemble.to("cuda:0")
actor = torch.compile(actor)
actor = actor.to("cuda:0")
obs_dict = {
"observation.image": torch.randn(1, 3, 84, 84),
"observation.state": torch.randn(1, 4),
}
actions = torch.randn(1, 2).to("cuda:0")
obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
print("compiling...")
# q_value = critic_ensemble(obs_dict, actions)
action = actor(obs_dict)
print("compiled")
start = time.perf_counter()
for _ in range(1000):
# features = encoder(obs_dict)
action = actor(obs_dict)
# q_value = critic_ensemble(obs_dict, actions)
print("Time taken:", time.perf_counter() - start)

View File

@@ -52,6 +52,8 @@ policy:
n_action_steps: 1
shared_encoder: true
# vision_encoder_name: null
freeze_vision_encoder: false
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]

View File

@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
# pretrained_policy_name_or_path=None,
# device=device,
# )
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
# HACK for maniskill
@@ -237,7 +238,9 @@ def act_with_policy(cfg: DictConfig):
logging.debug("[ACTOR] Load new parameters from Learner.")
state_dict = parameters_queue.get()
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.actor.load_state_dict(state_dict)
# strict=False for the case when the image encoder is frozen and not sent through
# the network. Be careful: this might cause issues if the wrong keys are passed
policy.actor.load_state_dict(state_dict, strict=False)
if len(list_transition_to_send_to_learner) > 0:
logging.debug(

View File

@@ -259,6 +259,9 @@ def learner_push_parameters(
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
buf = io.BytesIO()
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)

View File

@@ -13,26 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import functools
from pprint import pformat
import logging
import random
from typing import Optional, Sequence, TypedDict, Callable
from pprint import pformat
from typing import Callable, Optional, Sequence, TypedDict
import hydra
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from torch import nn
from tqdm import tqdm
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy