diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 4baf7d88..8ea00a1b 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -18,7 +18,8 @@ # TODO: (1) better device management from copy import deepcopy -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, Union, Dict +from pathlib import Path import einops import numpy as np @@ -142,6 +143,131 @@ class SACPolicy( self.log_alpha = nn.Parameter(torch.tensor([0.0])) self.temperature = self.log_alpha.exp().item() + def _save_pretrained(self, save_directory): + """Custom save method to handle TensorDict properly""" + import os + import json + from dataclasses import asdict + from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME + from safetensors.torch import save_file + + # NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap + # implies one side effect: the __batch_size parameters are saved in the state_dict + # __batch_size is torch.Size or safetensor save only torch.Tensor + # so we need to filter them out before saving + simplified_state_dict = {} + + for name, param in self.named_parameters(): + simplified_state_dict[name] = param + save_file( + simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE) + ) + + # Save config + config_dict = asdict(self.config) + with open(os.path.join(save_directory, CONFIG_NAME), "w") as f: + json.dump(config_dict, f, indent=2) + print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}") + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Optional[Union[str, bool]], + map_location: str = "cpu", + strict: bool = False, + **model_kwargs, + ) -> "SACPolicy": + """Custom load method to handle loading SAC policy from saved files""" + import os + import json + from pathlib import Path + from huggingface_hub import hf_hub_download + from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME + from safetensors.torch import load_file + from lerobot.common.policies.sac.configuration_sac import SACConfig + + # Check if model_id is a local path or a hub model ID + if os.path.isdir(model_id): + model_path = Path(model_id) + safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE) + config_file = os.path.join(model_path, CONFIG_NAME) + else: + # Download the safetensors file from the hub + safetensors_file = hf_hub_download( + repo_id=model_id, + filename=SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + # Download the config file + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except Exception: + config_file = None + + # Load or create config + if config_file and os.path.exists(config_file): + # Load config from file + with open(config_file, "r") as f: + config_dict = json.load(f) + config = SACConfig(**config_dict) + else: + # Use the provided config or create a default one + config = model_kwargs.get("config", SACConfig()) + + # Create a new instance with the loaded config + model = cls(config=config) + + # Load state dict from safetensors file + if os.path.exists(safetensors_file): + # Note: The load_file function returns a dict with the parameters, but __batch_size + # is not loaded so we need to copy it from the model state_dict + # Load the parameters only + loaded_state_dict = load_file(safetensors_file, device=map_location) + + # Copy batch size parameters + find_and_copy_params( + original_state_dict=model.state_dict(), + loaded_state_dict=loaded_state_dict, + pattern="__batch_size", + match_type="endswith", + ) + + # Copy normalization buffer parameters + find_and_copy_params( + original_state_dict=model.state_dict(), + loaded_state_dict=loaded_state_dict, + pattern="_orig_mod.output_normalization.buffer_action", + match_type="contains", + ) + + model.load_state_dict(loaded_state_dict, strict=False) + + return model + def reset(self): """Reset the policy""" pass @@ -276,6 +402,9 @@ class SACPolicy( actions_pi, log_probs, _ = self.actor(observations, observation_features) + # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way + actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"] + q_preds = self.critic_forward( observations, actions_pi, @@ -334,6 +463,50 @@ class MLP(nn.Module): return self.net(x) +def find_and_copy_params( + original_state_dict: dict[str, torch.Tensor], + loaded_state_dict: dict[str, torch.Tensor], + pattern: str, + match_type: str = "contains", +) -> list[str]: + """Find and copy parameters from original state dict to loaded state dict based on a pattern. + + This function can search for keys in different ways based on the match_type: + - "exact": The key must exactly match the pattern + - "contains": The key must contain the pattern anywhere + - "startswith": The key must start with the pattern + - "endswith": The key must end with the pattern + + Args: + original_state_dict: The source state dictionary + loaded_state_dict: The target state dictionary + pattern: The pattern to search for in keys + match_type: How to match the pattern (exact, contains, startswith, endswith) + + Returns: + list[str]: List of keys that were copied + """ + copied_keys = [] + + for key in original_state_dict: + should_copy = False + + if match_type == "exact": + should_copy = key == pattern + elif match_type == "contains": + should_copy = pattern in key + elif match_type == "startswith": + should_copy = key.startswith(pattern) + elif match_type == "endswith": + should_copy = key.endswith(pattern) + + if should_copy: + loaded_state_dict[key] = original_state_dict[key] + copied_keys.append(key) + + return copied_keys + + class CriticHead(nn.Module): def __init__( self,