Handle new config with sac

AdilZouitine 2025-03-24 20:19:28 +00:00
parent b2025b852c
commit f483931fc0
4 changed files with 242 additions and 240 deletions

View File

@@ -55,10 +55,9 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
return PI0Policy
elif name == "sac":
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy, SACConfig
return SACPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
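A quick sketch of the changed call site (hypothetical usage; `my_sac_config` stands in for a fully populated SACConfig built elsewhere):

    from lerobot.common.policies.factory import get_policy_class

    policy_cls = get_policy_class("sac")       # now returns SACPolicy alone, not (SACPolicy, SACConfig)
    policy = policy_cls(config=my_sac_config)  # the config is constructed and populated separately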

View File

@@ -79,28 +79,38 @@ def create_stats_buffers(
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if stats and key in stats:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" not in stats[key] or "std" not in stats[key]:
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
if isinstance(stats[key]["mean"], np.ndarray):
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" not in stats[key] or "max" not in stats[key]:
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
if isinstance(stats[key]["min"], np.ndarray):
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["min"], torch.Tensor):
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
type_ = type(stats[key]["min"])
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers
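The reworked branches only validate the stats fields that a key's normalization mode actually needs, and keys absent from `stats` keep their default buffers instead of raising. A sketch of that tolerance for partial stats (module path and type imports assumed from the rest of lerobot):

    import numpy as np

    from lerobot.common.policies.normalize import create_stats_buffers  # module path assumed
    from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

    features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(4,)),
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)),
    }
    norm_map = {
        FeatureType.STATE: NormalizationMode.MIN_MAX,
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
    }
    # Stats cover only the state key; the image buffers simply keep their
    # defaults rather than raising a KeyError as before.
    stats = {"observation.state": {"min": np.zeros(4), "max": np.ones(4)}}
    buffers = create_stats_buffers(features, norm_map, stats=stats)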
@@ -155,6 +165,7 @@ class Normalize(nn.Module):
for key, ft in self.features.items():
if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent failures!
# NOTE: (azouitine) This `continue` helps us instantiate SACPolicy
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
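In effect, a batch may now omit a declared feature and still pass through. A tiny hypothetical illustration, with `normalize` being a Normalize instance over the features above:

    import torch

    batch = {"observation.state": torch.rand(8, 4)}  # no "observation.image" key
    out = normalize(batch)  # the missing key is skipped instead of raising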

View File

@@ -18,57 +18,82 @@
from dataclasses import dataclass, field
from typing import Any
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("sac")
@dataclass
class SACConfig:
input_shapes: dict[str, list[int]] = field(
class SACConfig(PreTrainedConfig):
"""Configuration class for Soft Actor-Critic (SAC) policy.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
normalization_mapping: Mapping from feature types to normalization modes.
camera_number: Number of cameras to use.
storage_device: Device to use for storage.
vision_encoder_name: Name of the vision encoder to use.
freeze_vision_encoder: Whether to freeze the vision encoder.
image_encoder_hidden_dim: Hidden dimension for the image encoder.
shared_encoder: Whether to use a shared encoder.
discount: Discount factor for the RL algorithm.
temperature_init: Initial temperature for entropy regularization.
num_critics: Number of critic networks.
num_subsample_critics: Number of critics to subsample.
critic_lr: Learning rate for critic networks.
actor_lr: Learning rate for actor network.
temperature_lr: Learning rate for temperature parameter.
critic_target_update_weight: Weight for soft target updates.
utd_ratio: Update-to-data ratio (>1 to enable).
state_encoder_hidden_dim: Hidden dimension for state encoder.
latent_dim: Dimension of latent representation.
target_entropy: Target entropy for automatic temperature tuning.
use_backup_entropy: Whether to use backup entropy.
grad_clip_norm: Gradient clipping norm.
critic_network_kwargs: Additional arguments for critic networks.
actor_network_kwargs: Additional arguments for actor network.
policy_kwargs: Additional arguments for policy.
"""
# Input / output structure
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
input_normalization_params: dict[str, dict[str, list[float]]] = field(
dataset_stats: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {
"mean": [[0.485, 0.456, 0.406]],
"std": [[0.229, 0.224, 0.225]],
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": {"min": [-1, -1], "max": [1, 1]},
}
)
# TODO: Move it outside of the config
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
}
)
camera_number: int = 1
# Architecture specifics
camera_number: int = 1
storage_device: str = "cpu"
# Add type annotations for these fields:
vision_encoder_name: str | None = field(default="helper2424/resnet10")
# Set to "helper2424/resnet10" for hil serl
vision_encoder_name: str | None = None
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
# SAC algorithm parameters
discount: float = 0.99
temperature_init: float = 1.0
num_critics: int = 2
@@ -83,6 +108,8 @@ class SACConfig:
target_entropy: float | None = None
use_backup_entropy: bool = True
grad_clip_norm: float = 40.0
# Network configuration
critic_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
@@ -104,3 +131,52 @@ class SACConfig:
"init_final": 0.05,
}
)
# Deprecated, kept for backward compatibility
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
}
)
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
},
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
# TODO: Maybe we should remove this raise?
if len(self.image_features) == 0:
raise ValueError("You must provide at least one image among the inputs.")
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return [0] # SAC typically predicts one action at a time
@property
def reward_delta_indices(self) -> None:
return None
if __name__ == "__main__":
import draccus
config = SACConfig()
draccus.set_config_type("json")
draccus.dump(config=config, stream=open("run_config.json", "w"))
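The block above dumps the default config to JSON. draccus also provides a load counterpart, so a round trip would look roughly like this (API usage inferred from the dump call; treat it as a sketch):

    import draccus

    from lerobot.common.policies.sac.configuration_sac import SACConfig

    draccus.set_config_type("json")
    with open("run_config.json") as f:
        restored = draccus.load(SACConfig, f)
    assert restored.discount == 0.99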

View File

@@ -29,18 +29,17 @@ import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
PreTrainedPolicy,
):
config_class = SACConfig
name = "sac"
def __init__(
@@ -48,35 +47,33 @@ class SACPolicy(
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
if config is None:
config = SACConfig()
super().__init__(config)
config.validate_features()
self.config = config
if config.input_normalization_modes is not None:
if config.dataset_stats is not None:
input_normalization_params = _convert_normalization_params_to_tensor(
config.input_normalization_params
config.dataset_stats
)
self.normalize_inputs = Normalize(
config.input_shapes,
config.input_normalization_modes,
config.input_features,
config.normalization_mapping,
input_normalization_params,
)
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(
config.output_normalization_params
config.dataset_stats
)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
# NOTE: For images the encoder should be shared between the actor and critic
@@ -90,7 +87,7 @@ class SACPolicy(
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@@ -105,7 +102,7 @@ class SACPolicy(
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@@ -125,12 +122,12 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
action_dim=config.output_features["action"].shape[0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
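# e.g. a 2-dimensional action space gives target_entropy = -2 / 2 = -1.0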
# TODO (azouitine): Handle the case where the temperature parameter is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -140,102 +137,13 @@ class SACPolicy(
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
self.temperature = self.log_alpha.exp().item()
def _save_pretrained(self, save_directory):
"""Custom save method to handle TensorDict properly"""
import json
import os
from dataclasses import asdict
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import save_model
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
# Save config
config_dict = asdict(self.config)
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
json.dump(config_dict, f, indent=2)
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: Optional[bool],
local_files_only: bool,
token: Optional[Union[str, bool]],
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
) -> "SACPolicy":
"""Custom load method to handle loading SAC policy from saved files"""
import json
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
if os.path.isdir(model_id):
model_path = Path(model_id)
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
config_file = os.path.join(model_path, CONFIG_NAME)
else:
# Download the safetensors file from the hub
safetensors_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
# Download the config file
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except Exception:
config_file = None
# Load or create config
if config_file and os.path.exists(config_file):
# Load config from file
with open(config_file) as f:
config_dict = json.load(f)
config = SACConfig(**config_dict)
else:
# Use the provided config or create a default one
config = model_kwargs.get("config", SACConfig())
# Create a new instance with the loaded config
model = cls(config=config)
# Load state dict from safetensors file
if os.path.exists(safetensors_file):
load_model(model, filename=safetensors_file, device=map_location)
return model
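Since SACPolicy now derives from PreTrainedPolicy, the custom hub methods removed above become redundant; checkpointing is expected to flow through the inherited interface instead. A minimal sketch, assuming the standard save_pretrained / from_pretrained methods:

    policy.save_pretrained("outputs/sac_checkpoint")              # writes config + safetensors
    policy = SACPolicy.from_pretrained("outputs/sac_checkpoint")  # rebuilds and loads weights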
def get_optim_params(self) -> dict:
return {
"actor": self.actor.parameters_to_optimize,
"critic": self.critic_ensemble.parameters_to_optimize,
"temperature": self.log_alpha,
}
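These group names mirror the optimizer_groups declared in SACConfig.get_optimizer_preset(). A hedged sketch of how a trainer could pair them, assuming MultiAdamConfig.build accepts such a name-to-parameters mapping (the exact build signature is an assumption):

    optimizer_config = policy.config.get_optimizer_preset()  # MultiAdamConfig
    # Assumed: build() returns one Adam optimizer per named group, each using
    # the learning rate configured for that group in SACConfig.
    optimizers = optimizer_config.build(policy.get_optim_params())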
def reset(self):
"""Reset the policy"""
@@ -667,7 +575,7 @@ class SACObservationEncoder(nn.Module):
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_shapes):
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
@@ -682,12 +590,12 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_shapes:
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.state"][0],
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
@@ -697,10 +605,10 @@ class SACObservationEncoder(nn.Module):
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_shapes:
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.environment_state"][0],
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
@@ -727,9 +635,9 @@ class SACObservationEncoder(nn.Module):
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
@@ -744,11 +652,11 @@ class SACObservationEncoder(nn.Module):
class DefaultImageEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config: SACConfig):
super().__init__()
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
in_channels=config.input_features["observation.image"].shape[0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
@@ -776,7 +684,7 @@ class DefaultImageEncoder(nn.Module):
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
@@ -793,7 +701,7 @@ class DefaultImageEncoder(nn.Module):
class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config: SACConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
@@ -803,7 +711,7 @@ class PretrainedImageEncoder(nn.Module):
nn.Tanh(),
)
def _load_pretrained_vision_encoder(self, config):
def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder"""
from transformers import AutoModel
@@ -857,73 +765,73 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
if __name__ == "__main__":
# Benchmark the CriticEnsemble performance
import time
# # Benchmark the CriticEnsemble performance
# import time
# Configuration
num_critics = 10
batch_size = 32
action_dim = 7
obs_dim = 64
hidden_dims = [256, 256]
num_iterations = 100
# # Configuration
# num_critics = 10
# batch_size = 32
# action_dim = 7
# obs_dim = 64
# hidden_dims = [256, 256]
# num_iterations = 100
print("Creating test environment...")
# print("Creating test environment...")
# Create a simple dummy encoder
class DummyEncoder(nn.Module):
def __init__(self):
super().__init__()
self.output_dim = obs_dim
self.parameters_to_optimize = []
# # Create a simple dummy encoder
# class DummyEncoder(nn.Module):
# def __init__(self):
# super().__init__()
# self.output_dim = obs_dim
# self.parameters_to_optimize = []
def forward(self, obs):
# Just return a random tensor of the right shape
# In practice, this would encode the observations
return torch.randn(batch_size, obs_dim, device=device)
# def forward(self, obs):
# # Just return a random tensor of the right shape
# # In practice, this would encode the observations
# return torch.randn(batch_size, obs_dim, device=device)
# Create critic heads
print(f"Creating {num_critics} critic heads...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic_heads = [
CriticHead(
input_dim=obs_dim + action_dim,
hidden_dims=hidden_dims,
).to(device)
for _ in range(num_critics)
]
# # Create critic heads
# print(f"Creating {num_critics} critic heads...")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# critic_heads = [
# CriticHead(
# input_dim=obs_dim + action_dim,
# hidden_dims=hidden_dims,
# ).to(device)
# for _ in range(num_critics)
# ]
# Create the critic ensemble
print("Creating CriticEnsemble...")
critic_ensemble = CriticEnsemble(
encoder=DummyEncoder().to(device),
ensemble=critic_heads,
output_normalization=nn.Identity(),
).to(device)
# # Create the critic ensemble
# print("Creating CriticEnsemble...")
# critic_ensemble = CriticEnsemble(
# encoder=DummyEncoder().to(device),
# ensemble=critic_heads,
# output_normalization=nn.Identity(),
# ).to(device)
# Create random input data
print("Creating input data...")
obs_dict = {
"observation.state": torch.randn(batch_size, obs_dim, device=device),
}
actions = torch.randn(batch_size, action_dim, device=device)
# # Create random input data
# print("Creating input data...")
# obs_dict = {
# "observation.state": torch.randn(batch_size, obs_dim, device=device),
# }
# actions = torch.randn(batch_size, action_dim, device=device)
# Warmup run
print("Warming up...")
_ = critic_ensemble(obs_dict, actions)
# # Warmup run
# print("Warming up...")
# _ = critic_ensemble(obs_dict, actions)
# Time the forward pass
print(f"Running benchmark with {num_iterations} iterations...")
start_time = time.perf_counter()
for _ in range(num_iterations):
q_values = critic_ensemble(obs_dict, actions)
end_time = time.perf_counter()
# # Time the forward pass
# print(f"Running benchmark with {num_iterations} iterations...")
# start_time = time.perf_counter()
# for _ in range(num_iterations):
# q_values = critic_ensemble(obs_dict, actions)
# end_time = time.perf_counter()
# Print results
elapsed_time = end_time - start_time
print(f"Total time: {elapsed_time:.4f} seconds")
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# # Print results
# elapsed_time = end_time - start_time
# print(f"Total time: {elapsed_time:.4f} seconds")
# print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
# print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# Verify that all critic heads produce different outputs
# This confirms each critic head is unique
@@ -932,3 +840,11 @@ if __name__ == "__main__":
# for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
import draccus
from lerobot.configs import parser
@parser.wrap()
def main(config: SACConfig):
policy = SACPolicy(config=config)
print("yolo")
main()
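Putting the pieces together, the new config path can be exercised end to end roughly as follows. This is a sketch: the feature population is shown explicitly because SACConfig no longer hard-codes input/output shapes, and the PolicyFeature/FeatureType usage is assumed from lerobot.configs.types:

    from lerobot.common.policies.sac.configuration_sac import SACConfig
    from lerobot.common.policies.sac.modeling_sac import SACPolicy
    from lerobot.configs.types import FeatureType, PolicyFeature

    config = SACConfig()
    config.input_features = {
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)),
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(4,)),
    }
    config.output_features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,))}
    policy = SACPolicy(config=config)  # validate_features passes: one image feature present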