Add custom save and load methods for SAC policy
- Implement `_save_pretrained` method to handle TensorDict state saving - Add `_from_pretrained` class method for loading SAC policy from files - Create utility function `find_and_copy_params` to handle parameter copying
This commit is contained in:
parent
7e3e1ce173
commit
a3ef7dc6c3
|
@ -18,7 +18,8 @@
|
||||||
# TODO: (1) better device management
|
# TODO: (1) better device management
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Callable, Optional, Tuple
|
from typing import Callable, Optional, Tuple, Union, Dict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -142,6 +143,131 @@ class SACPolicy(
|
||||||
self.log_alpha = nn.Parameter(torch.tensor([0.0]))
|
self.log_alpha = nn.Parameter(torch.tensor([0.0]))
|
||||||
self.temperature = self.log_alpha.exp().item()
|
self.temperature = self.log_alpha.exp().item()
|
||||||
|
|
||||||
|
def _save_pretrained(self, save_directory):
    """Write the policy parameters and its config into ``save_directory``.

    The parameters returned by ``named_parameters()`` are stored in a single
    safetensors file and the config dataclass is dumped as JSON next to it.

    Args:
        save_directory: Directory that receives the weights and config files.
    """
    import json
    import os
    from dataclasses import asdict

    from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
    from safetensors.torch import save_file

    # NOTE: The model uses tensordict.from_modules to batch inference with
    # torch.vmap. As a side effect, `__batch_size` entries end up in the
    # state_dict. Those entries are torch.Size objects, while safetensors can
    # only store torch.Tensor values, so only the named parameters are saved.
    weights_path = os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
    parameters_only = dict(self.named_parameters())
    save_file(parameters_only, weights_path)

    # The config is serialised before the file is opened, so a failure in
    # asdict() does not leave an empty config file behind.
    config_dict = asdict(self.config)
    config_path = os.path.join(save_directory, CONFIG_NAME)
    with open(config_path, "w") as f:
        json.dump(config_dict, f, indent=2)
    print(f"Saved config to {config_path}")
|
||||||
|
|
||||||
|
@classmethod
def _from_pretrained(
    cls,
    *,
    model_id: str,
    revision: Optional[str],
    cache_dir: Optional[Union[str, Path]],
    force_download: bool,
    proxies: Optional[Dict],
    resume_download: Optional[bool],
    local_files_only: bool,
    token: Optional[Union[str, bool]],
    map_location: str = "cpu",
    strict: bool = False,
    **model_kwargs,
) -> "SACPolicy":
    """Build a SAC policy from a local directory or a Hub repository.

    Args:
        model_id: Path of a local directory, or a Hub repo id when no such
            directory exists.
        revision, cache_dir, force_download, proxies, resume_download,
        local_files_only, token: Forwarded unchanged to ``hf_hub_download``
            when the files are fetched from the Hub.
        map_location: Device passed to ``load_file`` for the loaded tensors.
        strict: Forwarded to ``load_state_dict``.
        **model_kwargs: May contain a ``config`` entry, used only when no
            config file could be found.

    Returns:
        A new policy instance. If the weights file does not exist the
        instance keeps its freshly initialised parameters (a warning is
        logged).
    """
    import json
    import logging
    import os

    from huggingface_hub import hf_hub_download
    from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
    from safetensors.torch import load_file

    from lerobot.common.policies.sac.configuration_sac import SACConfig

    # Decide whether model_id is a local directory or a Hub repo id.
    if os.path.isdir(model_id):
        safetensors_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
        config_file = os.path.join(model_id, CONFIG_NAME)
    else:
        download_kwargs = {
            "repo_id": model_id,
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "token": token,
            "local_files_only": local_files_only,
        }
        # The weights are mandatory: a download failure propagates.
        safetensors_file = hf_hub_download(
            filename=SAFETENSORS_SINGLE_FILE, **download_kwargs
        )
        # The config is optional: fall back to the provided/default config,
        # but report why instead of swallowing the error silently.
        try:
            config_file = hf_hub_download(filename=CONFIG_NAME, **download_kwargs)
        except Exception as err:
            logging.warning(
                "Could not fetch %s for %s (%s); falling back to the "
                "provided or default config.",
                CONFIG_NAME,
                model_id,
                err,
            )
            config_file = None

    # Load the config from file when available, otherwise use the one given
    # by the caller, and only build a default SACConfig as a last resort.
    if config_file and os.path.exists(config_file):
        with open(config_file, "r") as f:
            config_dict = json.load(f)
        config = SACConfig(**config_dict)
    else:
        config = model_kwargs.get("config")
        if config is None:
            config = SACConfig()

    model = cls(config=config)

    if os.path.exists(safetensors_file):
        # load_file only returns the saved parameters; the `__batch_size`
        # entries are not in the file, so they are copied over from the
        # freshly built model's state_dict before loading.
        loaded_state_dict = load_file(safetensors_file, device=map_location)
        reference_state_dict = model.state_dict()

        # Copy batch size entries.
        find_and_copy_params(
            original_state_dict=reference_state_dict,
            loaded_state_dict=loaded_state_dict,
            pattern="__batch_size",
            match_type="endswith",
        )

        # Copy normalization buffer entries.
        find_and_copy_params(
            original_state_dict=reference_state_dict,
            loaded_state_dict=loaded_state_dict,
            pattern="_orig_mod.output_normalization.buffer_action",
            match_type="contains",
        )

        # Honour the caller's `strict` flag (defaults to False).
        model.load_state_dict(loaded_state_dict, strict=strict)
    else:
        logging.warning(
            "Weights file %s not found; returning a policy with freshly "
            "initialised parameters.",
            safetensors_file,
        )

    return model
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the policy"""
|
"""Reset the policy"""
|
||||||
pass
|
pass
|
||||||
|
@ -276,6 +402,9 @@ class SACPolicy(
|
||||||
|
|
||||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||||
|
|
||||||
|
# TODO: (maractingi, azouitine) This is too slow; we should find a more efficient way to do this
|
||||||
|
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
|
||||||
|
|
||||||
q_preds = self.critic_forward(
|
q_preds = self.critic_forward(
|
||||||
observations,
|
observations,
|
||||||
actions_pi,
|
actions_pi,
|
||||||
|
@ -334,6 +463,50 @@ class MLP(nn.Module):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
def find_and_copy_params(
    original_state_dict: dict[str, torch.Tensor],
    loaded_state_dict: dict[str, torch.Tensor],
    pattern: str,
    match_type: str = "contains",
) -> list[str]:
    """Copy entries whose key matches a pattern from one state dict to another.

    Matching entries of ``original_state_dict`` are assigned into
    ``loaded_state_dict`` in place (existing keys are overwritten).

    Supported values of ``match_type``:
        - "exact": the key must be equal to the pattern
        - "contains": the key must contain the pattern anywhere
        - "startswith": the key must start with the pattern
        - "endswith": the key must end with the pattern

    Args:
        original_state_dict: The source state dictionary.
        loaded_state_dict: The target state dictionary, modified in place.
        pattern: The pattern to search for in keys.
        match_type: How to match the pattern (see above).

    Returns:
        list[str]: Keys that were copied, in source iteration order.

    Raises:
        ValueError: If ``match_type`` is not one of the supported values.
    """
    matchers = {
        "exact": lambda key: key == pattern,
        "contains": lambda key: pattern in key,
        "startswith": lambda key: key.startswith(pattern),
        "endswith": lambda key: key.endswith(pattern),
    }
    try:
        matches = matchers[match_type]
    except KeyError:
        # An unknown match type used to copy nothing silently, which hides
        # typos at the call site; report it instead.
        raise ValueError(
            f"Unknown match_type {match_type!r}; expected one of {sorted(matchers)}"
        ) from None

    # Collect the keys first so the source is not iterated while the target
    # (which may be the same object) is being written.
    copied_keys = [key for key in original_state_dict if matches(key)]
    for key in copied_keys:
        loaded_state_dict[key] = original_state_dict[key]

    return copied_keys
|
||||||
|
|
||||||
|
|
||||||
class CriticHead(nn.Module):
|
class CriticHead(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue