From a3ef7dc6c37ecd3676eb279a26628e682cb2c907 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Wed, 12 Mar 2025 10:15:37 +0000
Subject: [PATCH] Add custom save and load methods for SAC policy

- Implement `_save_pretrained` method to handle TensorDict state saving
- Add `_from_pretrained` class method for loading SAC policy from files
- Create utility function `find_and_copy_params` to handle parameter copying
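
A minimal usage sketch (assuming SACPolicy inherits huggingface_hub's
PyTorchModelHubMixin, so these private hooks are invoked through the public
save_pretrained/from_pretrained API; the checkpoint path is illustrative):

    from lerobot.common.policies.sac.configuration_sac import SACConfig
    from lerobot.common.policies.sac.modeling_sac import SACPolicy

    policy = SACPolicy(config=SACConfig())
    policy.save_pretrained("outputs/sac_checkpoint")                # -> _save_pretrained
    restored = SACPolicy.from_pretrained("outputs/sac_checkpoint")  # -> _from_pretrained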
---
 lerobot/common/policies/sac/modeling_sac.py | 175 +++++++++++++++++++-
 1 file changed, 174 insertions(+), 1 deletion(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 4baf7d88..8ea00a1b 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -18,7 +18,8 @@
 # TODO: (1) better device management
 
 from copy import deepcopy
-from typing import Callable, Optional, Tuple
+from typing import Callable, Optional, Tuple, Union, Dict
+from pathlib import Path
 
 import einops
 import numpy as np
@@ -142,6 +143,131 @@ class SACPolicy(
         self.log_alpha = nn.Parameter(torch.tensor([0.0]))
         self.temperature = self.log_alpha.exp().item()
 
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Custom save method to handle TensorDict properly"""
+        import os
+        import json
+        from dataclasses import asdict
+        from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
+        from safetensors.torch import save_file
+
+        # NOTE: Using tensordict.from_modules in the model to batch inference with torch.vmap
+        # has one side effect: __batch_size entries end up in the state_dict. __batch_size is a
+        # torch.Size, but safetensors can only serialize torch.Tensor, so these entries must be
+        # filtered out before saving. Iterating over named_parameters() below does this
+        # implicitly (it also skips buffers, which are re-initialized at load time).
+        simplified_state_dict = {}
+
+        for name, param in self.named_parameters():
+            simplified_state_dict[name] = param
+        save_file(
+            simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
+        )
+
+        # Save config
+        config_dict = asdict(self.config)
+        with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
+            json.dump(config_dict, f, indent=2)
+            print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        *,
+        model_id: str,
+        revision: Optional[str],
+        cache_dir: Optional[Union[str, Path]],
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: Optional[bool],
+        local_files_only: bool,
+        token: Optional[Union[str, bool]],
+        map_location: str = "cpu",
+        strict: bool = False,
+        **model_kwargs,
+    ) -> "SACPolicy":
+        """Custom load method to handle loading SAC policy from saved files"""
+        import os
+        import json
+        from pathlib import Path
+        from huggingface_hub import hf_hub_download
+        from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
+        from safetensors.torch import load_file
+        from lerobot.common.policies.sac.configuration_sac import SACConfig
+
+        # Check if model_id is a local path or a hub model ID
+        if os.path.isdir(model_id):
+            model_path = Path(model_id)
+            safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
+            config_file = os.path.join(model_path, CONFIG_NAME)
+        else:
+            # Download the safetensors file from the hub
+            safetensors_file = hf_hub_download(
+                repo_id=model_id,
+                filename=SAFETENSORS_SINGLE_FILE,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                token=token,
+                local_files_only=local_files_only,
+            )
+            # Download the config file
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=CONFIG_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except Exception:
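+                # The config file may be missing from the repo; fall back to a default config below.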
+                config_file = None
+
+        # Load or create config
+        if config_file and os.path.exists(config_file):
+            # Load config from file
+            with open(config_file, "r") as f:
+                config_dict = json.load(f)
+            config = SACConfig(**config_dict)
+        else:
+            # Use the provided config or create a default one
+            config = model_kwargs.get("config", SACConfig())
+
+        # Create a new instance with the loaded config
+        model = cls(config=config)
+
+        # Load state dict from safetensors file
+        if os.path.exists(safetensors_file):
+            # Note: load_file returns only the tensors that were saved; the __batch_size
+            # entries were filtered out at save time, so we restore them from the freshly
+            # constructed model's state_dict before calling load_state_dict
+            loaded_state_dict = load_file(safetensors_file, device=map_location)
+
+            # Copy batch size parameters
+            find_and_copy_params(
+                original_state_dict=model.state_dict(),
+                loaded_state_dict=loaded_state_dict,
+                pattern="__batch_size",
+                match_type="endswith",
+            )
+
+            # Copy normalization buffers from the freshly constructed model, since
+            # _save_pretrained serializes parameters only (buffers are not saved)
+            find_and_copy_params(
+                original_state_dict=model.state_dict(),
+                loaded_state_dict=loaded_state_dict,
+                pattern="_orig_mod.output_normalization.buffer_action",
+                match_type="contains",
+            )
+
+            model.load_state_dict(loaded_state_dict, strict=strict)
+
+        return model
+
     def reset(self):
         """Reset the policy"""
         pass
@@ -276,6 +402,9 @@ class SACPolicy(
 
         actions_pi, log_probs, _ = self.actor(observations, observation_features)
 
+        # TODO: (maractingi, azouitine) This is too slow; we should find a more efficient way to do this
+        actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
+
         q_preds = self.critic_forward(
             observations,
             actions_pi,
@@ -334,6 +463,50 @@ class MLP(nn.Module):
         return self.net(x)
 
 
+def find_and_copy_params(
+    original_state_dict: dict[str, torch.Tensor],
+    loaded_state_dict: dict[str, torch.Tensor],
+    pattern: str,
+    match_type: str = "contains",
+) -> list[str]:
+    """Find and copy parameters from original state dict to loaded state dict based on a pattern.
+
+    This function can search for keys in different ways based on the match_type:
+    - "exact": The key must exactly match the pattern
+    - "contains": The key must contain the pattern anywhere
+    - "startswith": The key must start with the pattern
+    - "endswith": The key must end with the pattern
+
+    Args:
+        original_state_dict: The source state dictionary
+        loaded_state_dict: The target state dictionary
+        pattern: The pattern to search for in keys
+        match_type: How to match the pattern (exact, contains, startswith, endswith)
+
+    Returns:
+        list[str]: List of keys that were copied
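+
+    Example (illustrative; the keys are hypothetical):
+        >>> src = {"critic.__batch_size": torch.tensor([2])}
+        >>> dst = {}
+        >>> find_and_copy_params(src, dst, pattern="__batch_size", match_type="endswith")
+        ['critic.__batch_size']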
+    """
+    copied_keys = []
+
+    for key in original_state_dict:
+        should_copy = False
+
+        if match_type == "exact":
+            should_copy = key == pattern
+        elif match_type == "contains":
+            should_copy = pattern in key
+        elif match_type == "startswith":
+            should_copy = key.startswith(pattern)
+        elif match_type == "endswith":
+            should_copy = key.endswith(pattern)
+
+        if should_copy:
+            loaded_state_dict[key] = original_state_dict[key]
+            copied_keys.append(key)
+
+    return copied_keys
+
+
 class CriticHead(nn.Module):
     def __init__(
         self,