diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 2c4bad5f..646da874 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -18,7 +18,7 @@
 # TODO: (1) better device management
 
 from copy import deepcopy
-from typing import Callable, Optional, Tuple, Union, Dict
+from typing import Callable, Optional, Tuple, Union, Dict, List
 from pathlib import Path
 
 import einops
@@ -88,33 +88,33 @@ class SACPolicy(
         encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
         encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
 
+        # Create a list of critic heads
+        critic_heads = [
+            CriticHead(
+                input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                **config.critic_network_kwargs,
+            )
+            for _ in range(config.num_critics)
+        ]
+
         self.critic_ensemble = CriticEnsemble(
             encoder=encoder_critic,
-            ensemble=Ensemble(
-                [
-                    CriticHead(
-                        input_dim=encoder_critic.output_dim
-                        + config.output_shapes["action"][0],
-                        **config.critic_network_kwargs,
-                    )
-                    for _ in range(config.num_critics)
-                ]
-            ),
+            ensemble=critic_heads,
             output_normalization=self.normalize_targets,
         )
 
+        # Create a separate set of target critic heads with the same architecture as the main critics
+        target_critic_heads = [
+            CriticHead(
+                input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                **config.critic_network_kwargs,
+            )
+            for _ in range(config.num_critics)
+        ]
+
         self.critic_target = CriticEnsemble(
             encoder=encoder_critic,
-            ensemble=Ensemble(
-                [
-                    CriticHead(
-                        input_dim=encoder_critic.output_dim
-                        + config.output_shapes["action"][0],
-                        **config.critic_network_kwargs,
-                    )
-                    for _ in range(config.num_critics)
-                ]
-            ),
+            ensemble=target_critic_heads,
             output_normalization=self.normalize_targets,
         )
@@ -149,19 +149,9 @@ class SACPolicy(
         import json
         from dataclasses import asdict
         from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
-        from safetensors.torch import save_file
+        from safetensors.torch import save_model
 
-        # NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
-        # implies one side effect: the __batch_size parameters are saved in the state_dict
-        # __batch_size is torch.Size or safetensor save only torch.Tensor
-        # so we need to filter them out before saving
-        simplified_state_dict = {}
-
-        for name, param in self.named_parameters():
-            simplified_state_dict[name] = param
-        save_file(
-            simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
-        )
+        save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
 
         # Save config
         config_dict = asdict(self.config)
@@ -191,7 +181,7 @@ class SACPolicy(
         from pathlib import Path
         from huggingface_hub import hf_hub_download
         from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
-        from safetensors.torch import load_file
+        from safetensors.torch import load_model
         from lerobot.common.policies.sac.configuration_sac import SACConfig
 
         # Check if model_id is a local path or a hub model ID
@@ -243,28 +233,7 @@ class SACPolicy(
 
         # Load state dict from safetensors file
         if os.path.exists(safetensors_file):
-            # Note: The load_file function returns a dict with the parameters, but __batch_size
-            # is not loaded so we need to copy it from the model state_dict
-            # Load the parameters only
-            loaded_state_dict = load_file(safetensors_file, device=map_location)
-
-            # Copy batch size parameters
-            find_and_copy_params(
-                original_state_dict=model.state_dict(),
-                loaded_state_dict=loaded_state_dict,
-                pattern="__batch_size",
-                match_type="endswith",
-            )
-
-            # Copy normalization buffer parameters
-            find_and_copy_params(
-                original_state_dict=model.state_dict(),
-                loaded_state_dict=loaded_state_dict,
-                pattern="_orig_mod.output_normalization.buffer_action",
-                match_type="contains",
-            )
-
-            model.load_state_dict(loaded_state_dict, strict=False)
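+            # load_model restores the tensors into the module in place; the
+            # state dict now holds plain tensors only (no tensordict
+            # __batch_size entries), so no manual key fixups are required.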
+            load_model(model, filename=safetensors_file, device=map_location)
 
         return model
@@ -594,21 +563,21 @@ class CriticEnsemble(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        ensemble: "Ensemble[CriticHead]",
+        ensemble: List[CriticHead],
         output_normalization: nn.Module,
         init_final: Optional[float] = None,
     ):
         super().__init__()
         self.encoder = encoder
-        self.ensemble = ensemble
         self.init_final = init_final
         self.output_normalization = output_normalization
+        self.critics = nn.ModuleList(ensemble)
         self.parameters_to_optimize = []
 
         # Handle the case where a part of the encoder if frozen
         if self.encoder is not None:
             self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
-        self.parameters_to_optimize += list(self.ensemble.parameters())
+        self.parameters_to_optimize += list(self.critics.parameters())
 
     def forward(
         self,
@@ -632,8 +601,15 @@ class CriticEnsemble(nn.Module):
         )
         inputs = torch.cat([obs_enc, actions], dim=-1)
-        q_values = self.ensemble(inputs)  # [num_critics, B, 1]
-        return q_values.squeeze(-1)  # [num_critics, B]
+
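+        # NOTE: an explicit loop over an nn.ModuleList replaces the previous
+        # vmap-based Ensemble; it keeps the state dict free of tensordict
+        # metadata (e.g. __batch_size), which is what allows the plain
+        # safetensors save_model/load_model serialization path.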
+        # Loop through critics and collect outputs
+        q_values = []
+        for critic in self.critics:
+            q_values.append(critic(inputs))
+
+        # Stack outputs to match expected shape [num_critics, batch_size]
+        q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
+        return q_values
 
 
 class Policy(nn.Module):
@@ -706,9 +682,9 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
-            assert not torch.isnan(
-                log_std
-            ).any(), "[ERROR] log_std became NaN after std_layer!"
+            assert not torch.isnan(log_std).any(), (
+                "[ERROR] log_std became NaN after std_layer!"
+            )
 
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
@@ -1025,73 +1001,78 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
 
 
 if __name__ == "__main__":
-    # Test the SACObservationEncoder
+    # Benchmark the CriticEnsemble performance
     import time
 
-    config = SACConfig()
-    config.num_critics = 10
-    config.vision_encoder_name = None
-    encoder = SACObservationEncoder(config, nn.Identity())
-    # actor_encoder = SACObservationEncoder(config)
-    # encoder = torch.compile(encoder)
+    # Configuration
+    num_critics = 10
+    batch_size = 32
+    action_dim = 7
+    obs_dim = 64
+    hidden_dims = [256, 256]
+    num_iterations = 100
+
+    print("Creating test environment...")
+
+    # Create a simple dummy encoder
+    class DummyEncoder(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.output_dim = obs_dim
+            self.parameters_to_optimize = []
+
+        def forward(self, obs):
+            # Just return a random tensor of the right shape
+            # In practice, this would encode the observations
+            return torch.randn(batch_size, obs_dim, device=device)
+
+    # Create critic heads
+    print(f"Creating {num_critics} critic heads...")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    critic_heads = [
+        CriticHead(
+            input_dim=obs_dim + action_dim,
+            hidden_dims=hidden_dims,
+        ).to(device)
+        for _ in range(num_critics)
+    ]
+
+    # Create the critic ensemble
+    print("Creating CriticEnsemble...")
     critic_ensemble = CriticEnsemble(
-        encoder=encoder,
-        ensemble=Ensemble(
-            [
-                CriticHead(
-                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
-                    **config.critic_network_kwargs,
-                )
-                for _ in range(config.num_critics)
-            ]
-        ),
+        encoder=DummyEncoder().to(device),
+        ensemble=critic_heads,
         output_normalization=nn.Identity(),
-    )
-    # actor = Policy(
-    #     encoder=actor_encoder,
-    #     network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
-    #     action_dim=config.output_shapes["action"][0],
-    #     encoder_is_shared=config.shared_encoder,
-    #     **config.policy_kwargs,
-    # )
-    # encoder = encoder.to("cuda:0")
-    # critic_ensemble = torch.compile(critic_ensemble)
-    critic_ensemble = critic_ensemble.to("cuda:0")
-    # actor = torch.compile(actor)
-    # actor = actor.to("cuda:0")
+    ).to(device)
+
+    # Create random input data
+    print("Creating input data...")
     obs_dict = {
-        "observation.image": torch.randn(8, 3, 84, 84),
-        "observation.state": torch.randn(8, 4),
+        "observation.state": torch.randn(batch_size, obs_dim, device=device),
     }
-    actions = torch.randn(8, 2).to("cuda:0")
-    # obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
-    # print("compiling...")
-    q_value = critic_ensemble(obs_dict, actions)
-    print(q_value.size())
-    # action = actor(obs_dict)
-    # print("compiled")
-    # start = time.perf_counter()
-    # for _ in range(1000):
-    #     # features = encoder(obs_dict)
-    #     action = actor(obs_dict)
-    #     # q_value = critic_ensemble(obs_dict, actions)
-    # print("Time taken:", time.perf_counter() - start)
-    # Compare the performance of the ensemble vs a for loop of 16 MLPs
-    ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)])
-    ensemble = ensemble.to("cuda:0")
-    critic = CriticHead(256, [256, 256])
-    critic = critic.to("cuda:0")
-    data_ensemble = torch.randn(8, 256).to("cuda:0")
-    ensemble = torch.compile(ensemble)
-    # critic = torch.compile(critic)
-    print(ensemble(data_ensemble).size())
-    print(critic(data_ensemble).size())
-    start = time.perf_counter()
-    for _ in range(1000):
-        ensemble(data_ensemble)
-    print("Time taken:", time.perf_counter() - start)
-    start = time.perf_counter()
-    for _ in range(1000):
-        for i in range(2):
-            critic(data_ensemble)
-    print("Time taken:", time.perf_counter() - start)
+    actions = torch.randn(batch_size, action_dim, device=device)
+
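+    # One untimed forward pass first, so one-time costs (CUDA context
+    # creation, allocator warm-up) do not skew the timed loop below.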
+    # Warmup run
+    print("Warming up...")
+    _ = critic_ensemble(obs_dict, actions)
+
+    # Time the forward pass
+    print(f"Running benchmark with {num_iterations} iterations...")
+    start_time = time.perf_counter()
+    for _ in range(num_iterations):
+        q_values = critic_ensemble(obs_dict, actions)
+    end_time = time.perf_counter()
+
+    # Print results
+    elapsed_time = end_time - start_time
+    print(f"Total time: {elapsed_time:.4f} seconds")
+    print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
+    print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
+
+    # Verify that all critic heads produce different outputs
+    # This confirms each critic head is unique
+    # print("\nVerifying critic outputs are different:")
+    # for i in range(num_critics):
+    #     for j in range(i + 1, num_critics):
+    #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
+    #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index eb04effd..d1235980 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -121,7 +121,7 @@ def load_training_state(
         return None, None
 
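+    # The checkpointed training state bundles non-tensor objects (e.g. the
+    # optimizer state), which torch.load's weights-only mode cannot unpickle.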
     training_state = torch.load(
-        logger.last_checkpoint_dir / logger.training_state_file_name
+        logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
     )
 
     if isinstance(training_state["optimizer"], dict):
@@ -160,6 +160,7 @@ def initialize_replay_buffer(
             optimize_memory=True,
         )
 
+    logging.info("Resume training: loading the online dataset")
     dataset = LeRobotDataset(
         repo_id=cfg.dataset_repo_id,
         local_files_only=True,
@@ -174,6 +175,37 @@ def initialize_replay_buffer(
     )
 
 
+def initialize_offline_replay_buffer(
+    cfg: DictConfig,
+    logger: Logger,
+    device: str,
+    storage_device: str,
+    active_action_dims: list[int] | None = None,
+) -> ReplayBuffer:
+    if not cfg.resume:
+        logging.info("Building the offline dataset with make_dataset")
+        offline_dataset = make_dataset(cfg)
+    else:
+        logging.info("Loading the saved offline dataset")
+        offline_dataset = LeRobotDataset(
+            repo_id=cfg.dataset_repo_id,
+            local_files_only=True,
+            root=logger.log_dir / "dataset_offline",
+        )
+
+    logging.info("Converting the offline dataset to a replay buffer")
+    offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+        offline_dataset,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+        action_mask=active_action_dims,
+        action_delta=cfg.env.wrapper.delta_action,
+        storage_device=storage_device,
+        optimize_memory=True,
+    )
+    return offline_replay_buffer
+
+
 def get_observation_features(
     policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
 ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
@@ -447,9 +479,6 @@ def add_actor_information_and_train(
 
     offline_replay_buffer = None
     if cfg.dataset_repo_id is not None:
-        logging.info("make_dataset offline buffer")
-        offline_dataset = make_dataset(cfg)
-        logging.info("Convertion to a offline replay buffer")
         active_action_dims = None
         if cfg.env.wrapper.joint_masking_action_space is not None:
             active_action_dims = [
                 i
                 for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
                 if mask
             ]
-        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-            offline_dataset,
+        offline_replay_buffer = initialize_offline_replay_buffer(
+            cfg=cfg,
+            logger=logger,
             device=device,
-            state_keys=cfg.policy.input_shapes.keys(),
-            action_mask=active_action_dims,
-            action_delta=cfg.env.wrapper.delta_action,
             storage_device=storage_device,
-            optimize_memory=True,
+            active_action_dims=active_action_dims,
         )
 
         batch_size: int = batch_size // 2  # We will sample from both replay buffers
@@ -714,6 +741,19 @@ def add_actor_information_and_train(
             replay_buffer.to_lerobot_dataset(
                 cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
             )
+            if offline_replay_buffer is not None:
+                dataset_dir = logger.log_dir / "dataset_offline"
+
+                # Remove any previous offline export so a clean dataset is written
+                if dataset_dir.exists() and dataset_dir.is_dir():
+                    shutil.rmtree(dataset_dir)
+
+                offline_replay_buffer.to_lerobot_dataset(
+                    cfg.dataset_repo_id,
+                    fps=cfg.fps,
+                    root=logger.log_dir / "dataset_offline",
+                )
 
             logging.info("Resume training")
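For context, a minimal standalone sketch of the pattern the CriticEnsemble change adopts: an explicit loop over independent critic heads whose outputs are stacked into a [num_critics, batch_size] tensor. The two-layer head below is a hypothetical stand-in for CriticHead, not code from this patch:

    import torch
    from torch import nn

    num_critics, batch_size, input_dim = 10, 32, 64 + 7

    # Hypothetical stand-in for CriticHead: a small MLP mapping features to a Q-value.
    def make_head(dim: int) -> nn.Module:
        return nn.Sequential(nn.Linear(dim, 256), nn.ReLU(), nn.Linear(256, 1))

    heads = nn.ModuleList(make_head(input_dim) for _ in range(num_critics))
    x = torch.randn(batch_size, input_dim)  # concatenated obs/action features

    # Loop over the heads and stack the per-head outputs.
    q = torch.stack([head(x).squeeze(-1) for head in heads], dim=0)
    assert q.shape == (num_critics, batch_size)

Because the heads live in a plain nn.ModuleList, the module's state dict contains only ordinary tensors, which is what lets the patch swap the manual state-dict filtering for safetensors' save_model/load_model.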