Handle new config with sac

This commit is contained in:
AdilZouitine 2025-03-24 20:19:28 +00:00
parent b2025b852c
commit f483931fc0
4 changed files with 242 additions and 240 deletions

View File

@ -55,10 +55,9 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
return PI0Policy return PI0Policy
elif name == "sac": elif name == "sac":
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy, SACConfig return SACPolicy
else: else:
raise NotImplementedError(f"Policy with name {name} is not implemented.") raise NotImplementedError(f"Policy with name {name} is not implemented.")

View File

@ -79,28 +79,38 @@ def create_stats_buffers(
) )
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats: if stats and key in stats:
if isinstance(stats[key]["mean"], np.ndarray): if norm_mode is NormalizationMode.MEAN_STD:
if norm_mode is NormalizationMode.MEAN_STD: if "mean" not in stats[key] or "std" not in stats[key]:
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
if isinstance(stats[key]["mean"], np.ndarray):
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX: elif isinstance(stats[key]["mean"], torch.Tensor):
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) # tensors anywhere (for example, when we use the same stats for normalization and
elif isinstance(stats[key]["mean"], torch.Tensor): # unnormalization). See the logic here
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX: else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" not in stats[key] or "max" not in stats[key]:
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
if isinstance(stats[key]["min"], np.ndarray):
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["min"], torch.Tensor):
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else: else:
type_ = type(stats[key]["mean"]) type_ = type(stats[key]["min"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
stats_buffers[key] = buffer stats_buffers[key] = buffer
return stats_buffers return stats_buffers
@ -155,6 +165,7 @@ class Normalize(nn.Module):
for key, ft in self.features.items(): for key, ft in self.features.items():
if key not in batch: if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail! # FIXME(aliberts, rcadene): This might lead to silent fail!
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
continue continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)

View File

@ -18,57 +18,82 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("sac")
@dataclass @dataclass
class SACConfig: class SACConfig(PreTrainedConfig):
input_shapes: dict[str, list[int]] = field( """Configuration class for Soft Actor-Critic (SAC) policy.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
normalization_mapping: Mapping from feature types to normalization modes.
camera_number: Number of cameras to use.
storage_device: Device to use for storage.
vision_encoder_name: Name of the vision encoder to use.
freeze_vision_encoder: Whether to freeze the vision encoder.
image_encoder_hidden_dim: Hidden dimension for the image encoder.
shared_encoder: Whether to use a shared encoder.
discount: Discount factor for the RL algorithm.
temperature_init: Initial temperature for entropy regularization.
num_critics: Number of critic networks.
num_subsample_critics: Number of critics to subsample.
critic_lr: Learning rate for critic networks.
actor_lr: Learning rate for actor network.
temperature_lr: Learning rate for temperature parameter.
critic_target_update_weight: Weight for soft target updates.
utd_ratio: Update-to-data ratio (>1 to enable).
state_encoder_hidden_dim: Hidden dimension for state encoder.
latent_dim: Dimension of latent representation.
target_entropy: Target entropy for automatic temperature tuning.
use_backup_entropy: Whether to use backup entropy.
grad_clip_norm: Gradient clipping norm.
critic_network_kwargs: Additional arguments for critic networks.
actor_network_kwargs: Additional arguments for actor network.
policy_kwargs: Additional arguments for policy.
"""
# Input / output structure
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": [3, 84, 84], "VISUAL": NormalizationMode.MEAN_STD,
"observation.state": [4], "STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
} }
) )
output_shapes: dict[str, list[int]] = field( dataset_stats: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": [2],
}
)
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
input_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": { "observation.image": {
"mean": [[0.485, 0.456, 0.406]], "mean": [0.485, 0.456, 0.406],
"std": [[0.229, 0.224, 0.225]], "std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
}, },
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
} }
) )
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": {"min": [-1, -1], "max": [1, 1]},
}
)
# TODO: Move it outside of the config
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
}
)
camera_number: int = 1
# Architecture specifics
camera_number: int = 1
storage_device: str = "cpu" storage_device: str = "cpu"
# Add type annotations for these fields: # Set to "helper2424/resnet10" for hil serl
vision_encoder_name: str | None = field(default="helper2424/resnet10") vision_encoder_name: str | None = None
freeze_vision_encoder: bool = True freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32 image_encoder_hidden_dim: int = 32
shared_encoder: bool = True shared_encoder: bool = True
# SAC algorithm parameters
discount: float = 0.99 discount: float = 0.99
temperature_init: float = 1.0 temperature_init: float = 1.0
num_critics: int = 2 num_critics: int = 2
@ -83,6 +108,8 @@ class SACConfig:
target_entropy: float | None = None target_entropy: float | None = None
use_backup_entropy: bool = True use_backup_entropy: bool = True
grad_clip_norm: float = 40.0 grad_clip_norm: float = 40.0
# Network configuration
critic_network_kwargs: dict[str, Any] = field( critic_network_kwargs: dict[str, Any] = field(
default_factory=lambda: { default_factory=lambda: {
"hidden_dims": [256, 256], "hidden_dims": [256, 256],
@ -104,3 +131,52 @@ class SACConfig:
"init_final": 0.05, "init_final": 0.05,
} }
) )
# Deprecated, kept for backward compatibility
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
}
)
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
},
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
# TODO: Maybe we should remove this raise?
if len(self.image_features) == 0:
raise ValueError("You must provide at least one image among the inputs.")
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return [0] # SAC typically predicts one action at a time
@property
def reward_delta_indices(self) -> None:
return None
if __name__ == "__main__":
import draccus
config = SACConfig()
draccus.set_config_type("json")
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )

View File

@ -29,18 +29,17 @@ import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor from torch import Tensor
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy( class SACPolicy(
nn.Module, PreTrainedPolicy,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
): ):
config_class = SACConfig
name = "sac" name = "sac"
def __init__( def __init__(
@ -48,35 +47,33 @@ class SACPolicy(
config: SACConfig | None = None, config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None,
): ):
super().__init__() super().__init__(config)
config.validate_features()
if config is None:
config = SACConfig()
self.config = config self.config = config
if config.input_normalization_modes is not None: if config.dataset_stats is not None:
input_normalization_params = _convert_normalization_params_to_tensor( input_normalization_params = _convert_normalization_params_to_tensor(
config.input_normalization_params config.dataset_stats
) )
self.normalize_inputs = Normalize( self.normalize_inputs = Normalize(
config.input_shapes, config.input_features,
config.input_normalization_modes, config.normalization_mapping,
input_normalization_params, input_normalization_params,
) )
else: else:
self.normalize_inputs = nn.Identity() self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor( output_normalization_params = _convert_normalization_params_to_tensor(
config.output_normalization_params config.dataset_stats
) )
# HACK: This is hacky and should be removed # HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize( self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats config.output_features, config.normalization_mapping, dataset_stats
) )
self.unnormalize_outputs = Unnormalize( self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats config.output_features, config.normalization_mapping, dataset_stats
) )
# NOTE: For images the encoder should be shared between the actor and critic # NOTE: For images the encoder should be shared between the actor and critic
@ -90,7 +87,7 @@ class SACPolicy(
# Create a list of critic heads # Create a list of critic heads
critic_heads = [ critic_heads = [
CriticHead( CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs, **config.critic_network_kwargs,
) )
for _ in range(config.num_critics) for _ in range(config.num_critics)
@ -105,7 +102,7 @@ class SACPolicy(
# Create target critic heads as deepcopies of the original critic heads # Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [ target_critic_heads = [
CriticHead( CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs, **config.critic_network_kwargs,
) )
for _ in range(config.num_critics) for _ in range(config.num_critics)
@ -125,12 +122,12 @@ class SACPolicy(
self.actor = Policy( self.actor = Policy(
encoder=encoder_actor, encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0], action_dim=config.output_features["action"].shape[0],
encoder_is_shared=config.shared_encoder, encoder_is_shared=config.shared_encoder,
**config.policy_kwargs, **config.policy_kwargs,
) )
if config.target_entropy is None: if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2) config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@ -140,102 +137,13 @@ class SACPolicy(
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)])) self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
self.temperature = self.log_alpha.exp().item() self.temperature = self.log_alpha.exp().item()
def _save_pretrained(self, save_directory):
"""Custom save method to handle TensorDict properly"""
import json
import os
from dataclasses import asdict
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE def get_optim_params(self) -> dict:
from safetensors.torch import save_model return {
"actor": self.actor.parameters_to_optimize,
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)) "critic": self.critic_ensemble.parameters_to_optimize,
"temperature": self.log_alpha,
# Save config }
config_dict = asdict(self.config)
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
json.dump(config_dict, f, indent=2)
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: Optional[bool],
local_files_only: bool,
token: Optional[Union[str, bool]],
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
) -> "SACPolicy":
"""Custom load method to handle loading SAC policy from saved files"""
import json
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
if os.path.isdir(model_id):
model_path = Path(model_id)
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
config_file = os.path.join(model_path, CONFIG_NAME)
else:
# Download the safetensors file from the hub
safetensors_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
# Download the config file
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except Exception:
config_file = None
# Load or create config
if config_file and os.path.exists(config_file):
# Load config from file
with open(config_file) as f:
config_dict = json.load(f)
config = SACConfig(**config_dict)
else:
# Use the provided config or create a default one
config = model_kwargs.get("config", SACConfig())
# Create a new instance with the loaded config
model = cls(config=config)
# Load state dict from safetensors file
if os.path.exists(safetensors_file):
load_model(model, filename=safetensors_file, device=map_location)
return model
def reset(self): def reset(self):
"""Reset the policy""" """Reset the policy"""
@ -667,7 +575,7 @@ class SACObservationEncoder(nn.Module):
self.parameters_to_optimize = [] self.parameters_to_optimize = []
self.aggregation_size: int = 0 self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_shapes): if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None: if self.config.vision_encoder_name is not None:
@ -682,12 +590,12 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers) freeze_image_encoder(self.image_enc_layers)
else: else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters()) self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_shapes: if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential( self.state_enc_layers = nn.Sequential(
nn.Linear( nn.Linear(
in_features=config.input_shapes["observation.state"][0], in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim, out_features=config.latent_dim,
), ),
nn.LayerNorm(normalized_shape=config.latent_dim), nn.LayerNorm(normalized_shape=config.latent_dim),
@ -697,10 +605,10 @@ class SACObservationEncoder(nn.Module):
self.parameters_to_optimize += list(self.state_enc_layers.parameters()) self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_shapes: if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential( self.env_state_enc_layers = nn.Sequential(
nn.Linear( nn.Linear(
in_features=config.input_shapes["observation.environment_state"][0], in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim, out_features=config.latent_dim,
), ),
nn.LayerNorm(normalized_shape=config.latent_dim), nn.LayerNorm(normalized_shape=config.latent_dim),
@ -727,9 +635,9 @@ class SACObservationEncoder(nn.Module):
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks) feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes: if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes: if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"])) feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1) features = torch.cat(tensors=feat, dim=-1)
@ -744,11 +652,11 @@ class SACObservationEncoder(nn.Module):
class DefaultImageEncoder(nn.Module): class DefaultImageEncoder(nn.Module):
def __init__(self, config): def __init__(self, config: SACConfig):
super().__init__() super().__init__()
self.image_enc_layers = nn.Sequential( self.image_enc_layers = nn.Sequential(
nn.Conv2d( nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0], in_channels=config.input_features["observation.image"].shape[0],
out_channels=config.image_encoder_hidden_dim, out_channels=config.image_encoder_hidden_dim,
kernel_size=7, kernel_size=7,
stride=2, stride=2,
@ -776,7 +684,7 @@ class DefaultImageEncoder(nn.Module):
), ),
nn.ReLU(), nn.ReLU(),
) )
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
with torch.inference_mode(): with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:] self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend( self.image_enc_layers.extend(
@ -793,7 +701,7 @@ class DefaultImageEncoder(nn.Module):
class PretrainedImageEncoder(nn.Module): class PretrainedImageEncoder(nn.Module):
def __init__(self, config): def __init__(self, config: SACConfig):
super().__init__() super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
@ -803,7 +711,7 @@ class PretrainedImageEncoder(nn.Module):
nn.Tanh(), nn.Tanh(),
) )
def _load_pretrained_vision_encoder(self, config): def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder""" """Set up CNN encoder"""
from transformers import AutoModel from transformers import AutoModel
@ -857,73 +765,73 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
if __name__ == "__main__": if __name__ == "__main__":
# Benchmark the CriticEnsemble performance # # Benchmark the CriticEnsemble performance
import time # import time
# Configuration # # Configuration
num_critics = 10 # num_critics = 10
batch_size = 32 # batch_size = 32
action_dim = 7 # action_dim = 7
obs_dim = 64 # obs_dim = 64
hidden_dims = [256, 256] # hidden_dims = [256, 256]
num_iterations = 100 # num_iterations = 100
print("Creating test environment...") # print("Creating test environment...")
# Create a simple dummy encoder # # Create a simple dummy encoder
class DummyEncoder(nn.Module): # class DummyEncoder(nn.Module):
def __init__(self): # def __init__(self):
super().__init__() # super().__init__()
self.output_dim = obs_dim # self.output_dim = obs_dim
self.parameters_to_optimize = [] # self.parameters_to_optimize = []
def forward(self, obs): # def forward(self, obs):
# Just return a random tensor of the right shape # # Just return a random tensor of the right shape
# In practice, this would encode the observations # # In practice, this would encode the observations
return torch.randn(batch_size, obs_dim, device=device) # return torch.randn(batch_size, obs_dim, device=device)
# Create critic heads # # Create critic heads
print(f"Creating {num_critics} critic heads...") # print(f"Creating {num_critics} critic heads...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic_heads = [ # critic_heads = [
CriticHead( # CriticHead(
input_dim=obs_dim + action_dim, # input_dim=obs_dim + action_dim,
hidden_dims=hidden_dims, # hidden_dims=hidden_dims,
).to(device) # ).to(device)
for _ in range(num_critics) # for _ in range(num_critics)
] # ]
# Create the critic ensemble # # Create the critic ensemble
print("Creating CriticEnsemble...") # print("Creating CriticEnsemble...")
critic_ensemble = CriticEnsemble( # critic_ensemble = CriticEnsemble(
encoder=DummyEncoder().to(device), # encoder=DummyEncoder().to(device),
ensemble=critic_heads, # ensemble=critic_heads,
output_normalization=nn.Identity(), # output_normalization=nn.Identity(),
).to(device) # ).to(device)
# Create random input data # # Create random input data
print("Creating input data...") # print("Creating input data...")
obs_dict = { # obs_dict = {
"observation.state": torch.randn(batch_size, obs_dim, device=device), # "observation.state": torch.randn(batch_size, obs_dim, device=device),
} # }
actions = torch.randn(batch_size, action_dim, device=device) # actions = torch.randn(batch_size, action_dim, device=device)
# Warmup run # # Warmup run
print("Warming up...") # print("Warming up...")
_ = critic_ensemble(obs_dict, actions) # _ = critic_ensemble(obs_dict, actions)
# Time the forward pass # # Time the forward pass
print(f"Running benchmark with {num_iterations} iterations...") # print(f"Running benchmark with {num_iterations} iterations...")
start_time = time.perf_counter() # start_time = time.perf_counter()
for _ in range(num_iterations): # for _ in range(num_iterations):
q_values = critic_ensemble(obs_dict, actions) # q_values = critic_ensemble(obs_dict, actions)
end_time = time.perf_counter() # end_time = time.perf_counter()
# Print results # # Print results
elapsed_time = end_time - start_time # elapsed_time = end_time - start_time
print(f"Total time: {elapsed_time:.4f} seconds") # print(f"Total time: {elapsed_time:.4f} seconds")
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms") # print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size] # print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# Verify that all critic heads produce different outputs # Verify that all critic heads produce different outputs
# This confirms each critic head is unique # This confirms each critic head is unique
@ -932,3 +840,11 @@ if __name__ == "__main__":
# for j in range(i + 1, num_critics): # for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item() # diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}") # print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
import draccus
from lerobot.configs import parser
@parser.wrap()
def main(config: SACConfig):
policy = SACPolicy(config=config)
print("yolo")
main()