- Refactor observation encoder in `modeling_sac.py`
- Add `torch.compile` to the actor and learner servers
- Organize imports in `train_sac.py`
- Optimize the parameter push by not sending the frozen pre-trained encoder

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
parent f1c8bfe01e
commit 506821c7df
@@ -55,9 +55,10 @@ class SACConfig:
    )
    camera_number: int = 1
    # Add type annotations for these fields:
    vision_encoder_name: str = field(default="microsoft/resnet-18")
    vision_encoder_name: str | None = field(default="microsoft/resnet-18")
    freeze_vision_encoder: bool = True
    image_encoder_hidden_dim: int = 32
    shared_encoder: bool = False
    shared_encoder: bool = True
    discount: float = 0.99
    temperature_init: float = 1.0
    num_critics: int = 2
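For orientation (not part of the diff): the `str | None` annotation is what lets the same config select either the pretrained backbone or the small conv encoder defined further down. A hedged sketch, assuming the config is importable from the SAC configuration module (import path and keyword usage are assumptions):

```python
# Illustrative sketch, assumed import path; not taken from this commit.
from lerobot.common.policies.sac.configuration_sac import SACConfig

pretrained_cfg = SACConfig(vision_encoder_name="microsoft/resnet-18", freeze_vision_encoder=True)
scratch_cfg = SACConfig(vision_encoder_name=None)  # falls back to DefaultImageEncoder below
```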
@@ -312,7 +312,7 @@ class CriticEnsemble(nn.Module):
    def __init__(
        self,
        encoder: Optional[nn.Module],
        network_list: nn.Module,
        network_list: nn.ModuleList,
        init_final: Optional[float] = None,
    ):
        super().__init__()
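The switch from `nn.Module` to `nn.ModuleList` matters because the ensemble indexes `network_list[0]` and relies on every critic's parameters being registered. A minimal sketch with toy layers (not the repo's MLP class):

```python
# Toy layers only: nn.ModuleList is indexable and registers every critic's parameters.
import torch
from torch import nn

network_list = nn.ModuleList([nn.Linear(8, 1) for _ in range(2)])
print(network_list[0])                                     # indexable, as in network_list[0].net below
print(sum(p.numel() for p in network_list.parameters()))   # 18 = 2 critics x (8 weights + 1 bias)
```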
@@ -320,6 +320,12 @@ class CriticEnsemble(nn.Module):
        self.network_list = network_list
        self.init_final = init_final

        self.parameters_to_optimize = []
        # Handle the case where a part of the encoder is frozen
        if self.encoder is not None:
            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)

        self.parameters_to_optimize += list(self.network_list.parameters())
        # Find the last Linear layer's output dimension
        for layer in reversed(network_list[0].net):
            if isinstance(layer, nn.Linear):
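A hedged sketch (stand-in modules, not the repo's training code) of why `parameters_to_optimize` is tracked separately: only trainable parameters are handed to the optimizer, so a frozen pretrained encoder contributes no gradients and no optimizer state.

```python
# Stand-in modules only: parameters_to_optimize is what the optimizer sees,
# so frozen encoder weights never enter it.
import torch
from torch import nn

encoder = nn.Linear(8, 8)            # stand-in for a frozen pretrained encoder
for p in encoder.parameters():
    p.requires_grad = False

critic_head = nn.Linear(8, 1)        # stand-in for one critic network
parameters_to_optimize = []
parameters_to_optimize += [p for p in encoder.parameters() if p.requires_grad]  # empty here
parameters_to_optimize += list(critic_head.parameters())

optimizer = torch.optim.Adam(parameters_to_optimize, lr=3e-4)
print(len(optimizer.param_groups[0]["params"]))  # 2 tensors: head weight and bias
```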
@@ -342,6 +348,7 @@ class CriticEnsemble(nn.Module):
            self.output_layers.append(output_layer)

        self.output_layers = nn.ModuleList(self.output_layers)
        self.parameters_to_optimize += list(self.output_layers.parameters())

    def forward(
        self,
@@ -474,61 +481,25 @@ class SACObservationEncoder(nn.Module):
        super().__init__()
        self.config = config
        self.has_pretrained_vision_encoder = False
        self.parameters_to_optimize = []

        self.aggregation_size: int = 0
        if "observation.image" in config.input_shapes:
            self.camera_number = config.camera_number
            self.aggregation_size: int = 0

            if self.config.vision_encoder_name is not None:
                self.image_enc_layers = PretrainedImageEncoder(config)
                self.has_pretrained_vision_encoder = True
                self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder()
                self.freeze_encoder()
                self.image_enc_proj = nn.Sequential(
                    nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
                    nn.LayerNorm(config.latent_dim),
                    nn.Tanh(),
                )
            else:
                self.image_enc_layers = nn.Sequential(
                    nn.Conv2d(
                        in_channels=config.input_shapes["observation.image"][0],
                        out_channels=config.image_encoder_hidden_dim,
                        kernel_size=7,
                        stride=2,
                    ),
                    nn.ReLU(),
                    nn.Conv2d(
                        in_channels=config.image_encoder_hidden_dim,
                        out_channels=config.image_encoder_hidden_dim,
                        kernel_size=5,
                        stride=2,
                    ),
                    nn.ReLU(),
                    nn.Conv2d(
                        in_channels=config.image_encoder_hidden_dim,
                        out_channels=config.image_encoder_hidden_dim,
                        kernel_size=3,
                        stride=2,
                    ),
                    nn.ReLU(),
                    nn.Conv2d(
                        in_channels=config.image_encoder_hidden_dim,
                        out_channels=config.image_encoder_hidden_dim,
                        kernel_size=3,
                        stride=2,
                    ),
                    nn.ReLU(),
                )
                dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
                with torch.inference_mode():
                    self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
                self.image_enc_layers.extend(
                    nn.Sequential(
                        nn.Flatten(),
                        nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
                        nn.LayerNorm(config.latent_dim),
                        nn.Tanh(),
                    )
                )
                self.image_enc_layers = DefaultImageEncoder(config)

            self.aggregation_size += config.latent_dim * self.camera_number

            if config.freeze_vision_encoder:
                freeze_image_encoder(self.image_enc_layers)
            else:
                self.parameters_to_optimize += list(self.image_enc_layers.parameters())

        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
                nn.Linear(
@@ -539,6 +510,8 @@ class SACObservationEncoder(nn.Module):
            )
            self.aggregation_size += config.latent_dim

            self.parameters_to_optimize += list(self.state_enc_layers.parameters())

        if "observation.environment_state" in config.input_shapes:
            self.env_state_enc_layers = nn.Sequential(
                nn.Linear(
@@ -548,26 +521,11 @@ class SACObservationEncoder(nn.Module):
                nn.LayerNorm(normalized_shape=config.latent_dim),
                nn.Tanh(),
            )
            self.aggregation_size += config.latent_dim
            self.aggregation_size += config.latent_dim
            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())

        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)

    def _load_pretrained_vision_encoder(self):
        """Set up CNN encoder"""
        from transformers import AutoModel

        self.image_enc_layers = AutoModel.from_pretrained(self.config.vision_encoder_name)
        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
        elif hasattr(self.image_enc_layers, "fc"):
            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
        else:
            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
        return self.image_enc_layers, self.image_enc_out_shape

    def freeze_encoder(self):
        """Freeze all parameters in the encoder"""
        for param in self.image_enc_layers.parameters():
            param.requires_grad = False
        self.parameters_to_optimize += list(self.aggregation_layer.parameters())

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.
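To make the aggregation explicit, a small worked example with assumed numbers (the real `latent_dim` and camera count come from the config): every modality contributes `latent_dim` features, images once per camera, and `aggregation_layer` maps the concatenation back down to `latent_dim`.

```python
# Worked example with assumed values, not taken from the config in this diff.
latent_dim = 256       # assumed config.latent_dim
camera_number = 2      # assumed number of cameras

aggregation_size = 0
aggregation_size += latent_dim * camera_number  # image features
aggregation_size += latent_dim                  # observation.state
aggregation_size += latent_dim                  # observation.environment_state
print(aggregation_size)                         # 1024 -> in_features of aggregation_layer
```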
@@ -579,12 +537,10 @@ class SACObservationEncoder(nn.Module):
        # Concatenate all images along the channel dimension.
        image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
        for image_key in image_keys:
            if self.has_pretrained_vision_encoder:
                enc_feat = self.image_enc_layers(obs_dict[image_key]).pooler_output
                enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
            else:
                enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
            enc_feat = self.image_enc_layers(obs_dict[image_key])

            # if not self.has_pretrained_vision_encoder:
            #     enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
            feat.append(enc_feat)
        if "observation.environment_state" in self.config.input_shapes:
            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
@@ -602,10 +558,107 @@ class SACObservationEncoder(nn.Module):
        return self.config.latent_dim


class DefaultImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.image_enc_layers = nn.Sequential(
            nn.Conv2d(
                in_channels=config.input_shapes["observation.image"][0],
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=7,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=5,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=3,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=3,
                stride=2,
            ),
            nn.ReLU(),
        )
        dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
        with torch.inference_mode():
            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
        self.image_enc_layers.extend(
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )
        )

    def forward(self, x):
        return self.image_enc_layers(x)
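As a quick sanity check of the shape this conv stack produces, here is a standalone snippet with an assumed 84x84, 3-channel input (matching the benchmark at the end of the file) and the default hidden dim of 32: four stride-2 convolutions with kernels 7, 5, 3, 3 and no padding reduce 84x84 to 3x3, so the flattened feature fed to the Linear projection has 32 * 3 * 3 = 288 values.

```python
# Standalone shape check with assumed 84x84 RGB input and hidden dim 32.
import torch
from torch import nn

convs = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=7, stride=2), nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=5, stride=2), nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=3, stride=2), nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=3, stride=2), nn.ReLU(),
)
with torch.inference_mode():
    print(convs(torch.zeros(1, 3, 84, 84)).shape)  # torch.Size([1, 32, 3, 3])
```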
class PretrainedImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
        self.image_enc_proj = nn.Sequential(
            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
            nn.LayerNorm(config.latent_dim),
            nn.Tanh(),
        )

    def _load_pretrained_vision_encoder(self, config):
        """Set up CNN encoder"""
        from transformers import AutoModel

        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name)
        # self.image_enc_layers.pooler = Identity()

        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
        elif hasattr(self.image_enc_layers, "fc"):
            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
        else:
            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
        return self.image_enc_layers, self.image_enc_out_shape

    def forward(self, x):
        # TODO: (maractingi, azouitine) check that the forward pass of the pretrained model
        # doesn't reach the classifier layer, because we don't need it
        enc_feat = self.image_enc_layers(x).pooler_output
        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
        return enc_feat
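What `_load_pretrained_vision_encoder` relies on, shown standalone (requires the `transformers` package and the cached resnet-18 weights; the printed shapes are what I would expect for resnet-18, stated as an assumption): the HF ResNet backbone exposes `config.hidden_sizes[-1]` and returns a `pooler_output` with that channel dimension, which is exactly what the module flattens and projects to `latent_dim`.

```python
# Illustration only; needs `transformers` and network/cache access for the weights.
import torch
from transformers import AutoModel

backbone = AutoModel.from_pretrained("microsoft/resnet-18")
print(backbone.config.hidden_sizes[-1])        # 512 for resnet-18
with torch.inference_mode():
    out = backbone(torch.zeros(1, 3, 84, 84))
print(out.pooler_output.shape)                 # expected: torch.Size([1, 512, 1, 1])
```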
def freeze_image_encoder(image_encoder: nn.Module):
    """Freeze all parameters in the encoder"""
    for param in image_encoder.parameters():
        param.requires_grad = False


def orthogonal_init():
    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


# TODO (azouitine): I think in our case this function is not useful; we should remove it
# after some investigation
# borrowed from tdmpc
@@ -626,3 +679,54 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
    inp = torch.flatten(image_tensor, end_dim=-4)
    flat_out = fn(inp)
    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
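A usage sketch of this helper (shapes assumed for illustration, and assuming `flatten_forward_unflatten` from `modeling_sac.py` is in scope): extra leading dimensions are folded into the batch so a plain CNN can run, then restored on the output.

```python
# Usage sketch with assumed shapes; flatten_forward_unflatten comes from modeling_sac.py.
import torch
from torch import nn

conv = nn.Conv2d(3, 8, kernel_size=3, stride=2)
images = torch.randn(4, 5, 3, 84, 84)            # (batch, time, C, H, W)
out = flatten_forward_unflatten(conv, images)    # conv actually sees (20, 3, 84, 84)
print(out.shape)                                 # torch.Size([4, 5, 8, 41, 41])
```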
if __name__ == "__main__":
    # Test the SACObservationEncoder
    import time

    config = SACConfig()
    config.num_critics = 10
    encoder = SACObservationEncoder(config)
    actor_encoder = SACObservationEncoder(config)
    encoder = torch.compile(encoder)
    critic_ensemble = CriticEnsemble(
        encoder=encoder,
        network_list=nn.ModuleList(
            [
                MLP(
                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
                    **config.critic_network_kwargs,
                )
                for _ in range(config.num_critics)
            ]
        ),
    )
    actor = Policy(
        encoder=actor_encoder,
        network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
        action_dim=config.output_shapes["action"][0],
        encoder_is_shared=config.shared_encoder,
        **config.policy_kwargs,
    )
    encoder = encoder.to("cuda:0")
    critic_ensemble = torch.compile(critic_ensemble)
    critic_ensemble = critic_ensemble.to("cuda:0")
    actor = torch.compile(actor)
    actor = actor.to("cuda:0")
    obs_dict = {
        "observation.image": torch.randn(1, 3, 84, 84),
        "observation.state": torch.randn(1, 4),
    }
    actions = torch.randn(1, 2).to("cuda:0")
    obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
    print("compiling...")
    # q_value = critic_ensemble(obs_dict, actions)
    action = actor(obs_dict)
    print("compiled")
    start = time.perf_counter()
    for _ in range(1000):
        # features = encoder(obs_dict)
        action = actor(obs_dict)
        # q_value = critic_ensemble(obs_dict, actions)
    print("Time taken:", time.perf_counter() - start)
@@ -52,6 +52,8 @@ policy:
  n_action_steps: 1

  shared_encoder: true
  # vision_encoder_name: null
  freeze_vision_encoder: false
  input_shapes:
    # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.state: ["${env.state_dim}"]
@@ -191,6 +191,7 @@ def act_with_policy(cfg: DictConfig):
    #     pretrained_policy_name_or_path=None,
    #     device=device,
    # )
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    # HACK for maniskill
@@ -237,7 +238,9 @@
            logging.debug("[ACTOR] Load new parameters from Learner.")
            state_dict = parameters_queue.get()
            state_dict = move_state_dict_to_device(state_dict, device=device)
            policy.actor.load_state_dict(state_dict)
            # strict=False for the case when the image encoder is frozen and not sent through
            # the network. Be careful: this might cause issues if the wrong keys are passed.
            policy.actor.load_state_dict(state_dict, strict=False)

        if len(list_transition_to_send_to_learner) > 0:
            logging.debug(
@@ -259,6 +259,9 @@ def learner_push_parameters(
    while True:
        with policy_lock:
            params_dict = policy.actor.state_dict()
            if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
                params_dict = {k: v for k, v in params_dict.items() if not k.startswith("encoder.")}

        params_dict = move_state_dict_to_device(params_dict, device="cpu")
        # Serialize
        buf = io.BytesIO()
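Taken together with the `strict=False` change in the actor hunk above, here is a toy sketch (hypothetical `TinyActor`, not a repo class) of the contract: the learner strips `encoder.` keys because the frozen pretrained encoder is identical on both sides, and the actor tolerates the missing keys, leaving its local encoder weights untouched.

```python
# Toy sketch of the push/load contract; TinyActor is hypothetical.
from torch import nn

class TinyActor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)  # stand-in for the frozen pretrained encoder
        self.head = nn.Linear(4, 2)

learner_actor, remote_actor = TinyActor(), TinyActor()
params = {k: v for k, v in learner_actor.state_dict().items() if not k.startswith("encoder.")}
missing, unexpected = remote_actor.load_state_dict(params, strict=False)
print(missing)  # ['encoder.weight', 'encoder.bias'] keep their local values on the actor side
```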
@@ -541,6 +544,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
    # compile policy
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
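A minimal note on what the new `torch.compile` calls rely on here: the returned wrapper is still an `nn.Module` (hence the assert), and graph capture only happens lazily on the first forward pass rather than at wrap time.

```python
# Minimal sketch with a stand-in module, not the SAC policy itself.
import torch
from torch import nn

policy = nn.Linear(4, 2)
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
_ = policy(torch.randn(1, 4))  # first call triggers compilation
```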
@@ -13,26 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import functools
from pprint import pformat
import logging
import random
from typing import Optional, Sequence, TypedDict, Callable
from pprint import pformat
from typing import Callable, Optional, Sequence, TypedDict

import hydra
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from torch import nn
from tqdm import tqdm

# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy