Add custom save and load methods for SAC policy

- Implement `_save_pretrained` method to handle TensorDict state saving
- Add `_from_pretrained` class method for loading SAC policy from files
- Create utility function `find_and_copy_params` to handle parameter copying
This commit is contained in:
AdilZouitine 2025-03-12 10:15:37 +00:00
parent 7e3e1ce173
commit a3ef7dc6c3
1 changed files with 174 additions and 1 deletions

View File

@ -18,7 +18,8 @@
# TODO: (1) better device management # TODO: (1) better device management
from copy import deepcopy from copy import deepcopy
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple, Union, Dict
from pathlib import Path
import einops import einops
import numpy as np import numpy as np
@ -142,6 +143,131 @@ class SACPolicy(
self.log_alpha = nn.Parameter(torch.tensor([0.0])) self.log_alpha = nn.Parameter(torch.tensor([0.0]))
self.temperature = self.log_alpha.exp().item() self.temperature = self.log_alpha.exp().item()
def _save_pretrained(self, save_directory):
"""Custom save method to handle TensorDict properly"""
import os
import json
from dataclasses import asdict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import save_file
# NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
# implies one side effect: the __batch_size parameters are saved in the state_dict
# __batch_size is torch.Size or safetensor save only torch.Tensor
# so we need to filter them out before saving
simplified_state_dict = {}
for name, param in self.named_parameters():
simplified_state_dict[name] = param
save_file(
simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
)
# Save config
config_dict = asdict(self.config)
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
json.dump(config_dict, f, indent=2)
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: Optional[bool],
local_files_only: bool,
token: Optional[Union[str, bool]],
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
) -> "SACPolicy":
"""Custom load method to handle loading SAC policy from saved files"""
import os
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import load_file
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
if os.path.isdir(model_id):
model_path = Path(model_id)
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
config_file = os.path.join(model_path, CONFIG_NAME)
else:
# Download the safetensors file from the hub
safetensors_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
# Download the config file
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except Exception:
config_file = None
# Load or create config
if config_file and os.path.exists(config_file):
# Load config from file
with open(config_file, "r") as f:
config_dict = json.load(f)
config = SACConfig(**config_dict)
else:
# Use the provided config or create a default one
config = model_kwargs.get("config", SACConfig())
# Create a new instance with the loaded config
model = cls(config=config)
# Load state dict from safetensors file
if os.path.exists(safetensors_file):
# Note: The load_file function returns a dict with the parameters, but __batch_size
# is not loaded so we need to copy it from the model state_dict
# Load the parameters only
loaded_state_dict = load_file(safetensors_file, device=map_location)
# Copy batch size parameters
find_and_copy_params(
original_state_dict=model.state_dict(),
loaded_state_dict=loaded_state_dict,
pattern="__batch_size",
match_type="endswith",
)
# Copy normalization buffer parameters
find_and_copy_params(
original_state_dict=model.state_dict(),
loaded_state_dict=loaded_state_dict,
pattern="_orig_mod.output_normalization.buffer_action",
match_type="contains",
)
model.load_state_dict(loaded_state_dict, strict=False)
return model
def reset(self): def reset(self):
"""Reset the policy""" """Reset the policy"""
pass pass
@ -276,6 +402,9 @@ class SACPolicy(
actions_pi, log_probs, _ = self.actor(observations, observation_features) actions_pi, log_probs, _ = self.actor(observations, observation_features)
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
q_preds = self.critic_forward( q_preds = self.critic_forward(
observations, observations,
actions_pi, actions_pi,
@ -334,6 +463,50 @@ class MLP(nn.Module):
return self.net(x) return self.net(x)
def find_and_copy_params(
original_state_dict: dict[str, torch.Tensor],
loaded_state_dict: dict[str, torch.Tensor],
pattern: str,
match_type: str = "contains",
) -> list[str]:
"""Find and copy parameters from original state dict to loaded state dict based on a pattern.
This function can search for keys in different ways based on the match_type:
- "exact": The key must exactly match the pattern
- "contains": The key must contain the pattern anywhere
- "startswith": The key must start with the pattern
- "endswith": The key must end with the pattern
Args:
original_state_dict: The source state dictionary
loaded_state_dict: The target state dictionary
pattern: The pattern to search for in keys
match_type: How to match the pattern (exact, contains, startswith, endswith)
Returns:
list[str]: List of keys that were copied
"""
copied_keys = []
for key in original_state_dict:
should_copy = False
if match_type == "exact":
should_copy = key == pattern
elif match_type == "contains":
should_copy = pattern in key
elif match_type == "startswith":
should_copy = key.startswith(pattern)
elif match_type == "endswith":
should_copy = key.endswith(pattern)
if should_copy:
loaded_state_dict[key] = original_state_dict[key]
copied_keys.append(key)
return copied_keys
class CriticHead(nn.Module): class CriticHead(nn.Module):
def __init__( def __init__(
self, self,