Add custom save and load methods for SAC policy
- Implement `_save_pretrained` method to handle TensorDict state saving - Add `_from_pretrained` class method for loading SAC policy from files - Create utility function `find_and_copy_params` to handle parameter copying
This commit is contained in:
parent
7e3e1ce173
commit
a3ef7dc6c3
|
@ -18,7 +18,8 @@
|
|||
# TODO: (1) better device management
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, Optional, Tuple, Union, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
@ -142,6 +143,131 @@ class SACPolicy(
|
|||
self.log_alpha = nn.Parameter(torch.tensor([0.0]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def _save_pretrained(self, save_directory):
    """Write the policy weights and its config into ``save_directory``.

    Two files are produced: a safetensors file holding every entry yielded
    by ``self.named_parameters()``, and a JSON dump of ``self.config``.

    Args:
        save_directory: Directory that receives both files.
    """
    import json
    import os
    from dataclasses import asdict

    from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
    from safetensors.torch import save_file

    # NOTE: the model relies on tensordict.from_modules to batch inference
    # with torch.vmap. A side effect is that `__batch_size` entries show up
    # in the state_dict. Those are torch.Size objects, but safetensors can
    # only store torch.Tensor, so only the named parameters are written.
    weights_path = os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
    save_file(dict(self.named_parameters()), weights_path)

    # The config dataclass is stored next to the weights as plain JSON.
    config_payload = asdict(self.config)
    config_path = os.path.join(save_directory, CONFIG_NAME)
    with open(config_path, "w") as config_fp:
        json.dump(config_payload, config_fp, indent=2)
    print(f"Saved config to {config_path}")
|
||||
|
||||
@classmethod
def _from_pretrained(
    cls,
    *,
    model_id: str,
    revision: Optional[str],
    cache_dir: Optional[Union[str, Path]],
    force_download: bool,
    proxies: Optional[Dict],
    resume_download: Optional[bool],
    local_files_only: bool,
    token: Optional[Union[str, bool]],
    map_location: str = "cpu",
    strict: bool = False,
    **model_kwargs,
) -> "SACPolicy":
    """Build a SAC policy from a local directory or a hub repository.

    Args:
        model_id: Path of a local directory, or a hub repository id when no
            such directory exists.
        revision, cache_dir, force_download, proxies, resume_download,
        local_files_only, token: Forwarded unchanged to ``hf_hub_download``
            when the files are fetched from the hub.
        map_location: Device passed to safetensors' ``load_file``.
        strict: Forwarded to ``load_state_dict``. Defaults to ``False``.
        **model_kwargs: May carry a ``config`` entry, used only when no
            config file could be found.

    Returns:
        A new policy instance. If the weights file does not exist, the
        instance is returned with the parameters it was constructed with.
    """
    import json
    import os
    from pathlib import Path

    from huggingface_hub import hf_hub_download
    from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
    from safetensors.torch import load_file

    from lerobot.common.policies.sac.configuration_sac import SACConfig

    # Resolve both file locations: a local directory wins over the hub.
    if os.path.isdir(model_id):
        model_path = Path(model_id)
        safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
        config_file = os.path.join(model_path, CONFIG_NAME)
    else:
        # Both downloads share every argument except the file name.
        hub_kwargs = {
            "repo_id": model_id,
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "token": token,
            "local_files_only": local_files_only,
        }
        safetensors_file = hf_hub_download(filename=SAFETENSORS_SINGLE_FILE, **hub_kwargs)
        try:
            config_file = hf_hub_download(filename=CONFIG_NAME, **hub_kwargs)
        except Exception:
            # Deliberate best effort: a repository without a config file
            # falls back to the caller-provided or default config below.
            config_file = None

    if config_file and os.path.exists(config_file):
        with open(config_file, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = SACConfig(**config_dict)
    elif "config" in model_kwargs:
        config = model_kwargs["config"]
    else:
        # Only build the default config when nothing else is available.
        config = SACConfig()

    model = cls(config=config)

    if os.path.exists(safetensors_file):
        # load_file returns the saved parameters only; `__batch_size` is not
        # part of the file, so it is taken from the fresh model's state dict.
        loaded_state_dict = load_file(safetensors_file, device=map_location)
        model_state_dict = model.state_dict()

        # Copy batch size entries.
        find_and_copy_params(
            original_state_dict=model_state_dict,
            loaded_state_dict=loaded_state_dict,
            pattern="__batch_size",
            match_type="endswith",
        )

        # Copy normalization buffer entries.
        find_and_copy_params(
            original_state_dict=model_state_dict,
            loaded_state_dict=loaded_state_dict,
            pattern="_orig_mod.output_normalization.buffer_action",
            match_type="contains",
        )

        # Honour the caller's `strict` choice instead of always passing False.
        model.load_state_dict(loaded_state_dict, strict=strict)

    return model
|
||||
|
||||
def reset(self):
    """Reset the policy.

    This implementation performs no work and returns ``None``.
    """
    return None
|
||||
|
@ -276,6 +402,9 @@ class SACPolicy(
|
|||
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
# TODO: (maractingi, azouitine) This is too slow; we should find a more efficient way to do this
|
||||
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions_pi,
|
||||
|
@ -334,6 +463,50 @@ class MLP(nn.Module):
|
|||
return self.net(x)
|
||||
|
||||
|
||||
def find_and_copy_params(
    original_state_dict: dict[str, torch.Tensor],
    loaded_state_dict: dict[str, torch.Tensor],
    pattern: str,
    match_type: str = "contains",
) -> list[str]:
    """Copy entries whose key matches a pattern from one state dict to another.

    Every key of ``original_state_dict`` that matches ``pattern`` is written
    into ``loaded_state_dict`` (overwriting any existing entry), and the
    matching keys are returned in iteration order.

    Supported values of ``match_type``:
    - "exact": the key must equal the pattern
    - "contains": the key must contain the pattern anywhere
    - "startswith": the key must start with the pattern
    - "endswith": the key must end with the pattern

    Args:
        original_state_dict: The source state dictionary.
        loaded_state_dict: The target state dictionary; modified in place.
        pattern: The pattern to search for in keys.
        match_type: How to match the pattern (see the list above).

    Returns:
        The keys that were copied.

    Raises:
        ValueError: If ``match_type`` is not one of the supported values.
    """
    matchers = {
        "exact": lambda key: key == pattern,
        "contains": lambda key: pattern in key,
        "startswith": lambda key: key.startswith(pattern),
        "endswith": lambda key: key.endswith(pattern),
    }

    # Pick the predicate once, and fail loudly on a mistyped mode rather than
    # silently copying nothing.
    try:
        matches = matchers[match_type]
    except KeyError:
        raise ValueError(
            f"Unknown match_type {match_type!r}; expected one of {sorted(matchers)}"
        ) from None

    copied_keys = [key for key in original_state_dict if matches(key)]
    for key in copied_keys:
        loaded_state_dict[key] = original_state_dict[key]

    return copied_keys
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue