diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index fed799d9..a843ac1a 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -55,10 +55,9 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
 
         return PI0Policy
     elif name == "sac":
-        from lerobot.common.policies.sac.configuration_sac import SACConfig
         from lerobot.common.policies.sac.modeling_sac import SACPolicy
 
-        return SACPolicy, SACConfig
+        return SACPolicy
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")
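Note (reviewer sketch, not part of the diff): with this change the factory returns only the policy class, and building a config is left to the caller. A minimal check of the new contract, using only names touched in this PR:

from lerobot.common.policies.factory import get_policy_class

# The factory now resolves just the class; SACConfig is no longer returned alongside it.
policy_cls = get_policy_class("sac")
assert policy_cls.name == "sac"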
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) - elif norm_mode is NormalizationMode.MIN_MAX: + else: + type_ = type(stats[key]["mean"]) + raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.") + + elif norm_mode is NormalizationMode.MIN_MAX: + if "min" not in stats[key] or "max" not in stats[key]: + raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization") + + if isinstance(stats[key]["min"], np.ndarray): + buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) + buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) + elif isinstance(stats[key]["min"], torch.Tensor): buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) - else: - type_ = type(stats[key]["mean"]) - raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") + else: + type_ = type(stats[key]["min"]) + raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.") stats_buffers[key] = buffer return stats_buffers @@ -155,6 +165,7 @@ class Normalize(nn.Module): for key, ft in self.features.items(): if key not in batch: # FIXME(aliberts, rcadene): This might lead to silent fail! + # NOTE: (azouitine) This continues help us for instantiation SACPolicy continue norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 8d98ed40..b34e5f60 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -18,57 +18,82 @@ from dataclasses import dataclass, field from typing import Any +from lerobot.common.optim.optimizers import MultiAdamConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + +@PreTrainedConfig.register_subclass("sac") @dataclass -class SACConfig: - input_shapes: dict[str, list[int]] = field( +class SACConfig(PreTrainedConfig): + """Configuration class for Soft Actor-Critic (SAC) policy. + + Args: + n_obs_steps: Number of environment steps worth of observations to pass to the policy. + normalization_mapping: Mapping from feature types to normalization modes. + camera_number: Number of cameras to use. + storage_device: Device to use for storage. + vision_encoder_name: Name of the vision encoder to use. + freeze_vision_encoder: Whether to freeze the vision encoder. + image_encoder_hidden_dim: Hidden dimension for the image encoder. + shared_encoder: Whether to use a shared encoder. + discount: Discount factor for the RL algorithm. + temperature_init: Initial temperature for entropy regularization. + num_critics: Number of critic networks. + num_subsample_critics: Number of critics to subsample. + critic_lr: Learning rate for critic networks. + actor_lr: Learning rate for actor network. + temperature_lr: Learning rate for temperature parameter. + critic_target_update_weight: Weight for soft target updates. + utd_ratio: Update-to-data ratio (>1 to enable). + state_encoder_hidden_dim: Hidden dimension for state encoder. + latent_dim: Dimension of latent representation. + target_entropy: Target entropy for automatic temperature tuning. + use_backup_entropy: Whether to use backup entropy. 
diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 8d98ed40..b34e5f60 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -18,57 +18,82 @@
 from dataclasses import dataclass, field
 from typing import Any
 
+from lerobot.common.optim.optimizers import MultiAdamConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import NormalizationMode
+
+@PreTrainedConfig.register_subclass("sac")
 @dataclass
-class SACConfig:
-    input_shapes: dict[str, list[int]] = field(
+class SACConfig(PreTrainedConfig):
+    """Configuration class for Soft Actor-Critic (SAC) policy.
+
+    Args:
+        n_obs_steps: Number of environment steps worth of observations to pass to the policy.
+        normalization_mapping: Mapping from feature types to normalization modes.
+        camera_number: Number of cameras to use.
+        storage_device: Device to use for storage.
+        vision_encoder_name: Name of the vision encoder to use.
+        freeze_vision_encoder: Whether to freeze the vision encoder.
+        image_encoder_hidden_dim: Hidden dimension for the image encoder.
+        shared_encoder: Whether to use a shared encoder.
+        discount: Discount factor for the RL algorithm.
+        temperature_init: Initial temperature for entropy regularization.
+        num_critics: Number of critic networks.
+        num_subsample_critics: Number of critics to subsample.
+        critic_lr: Learning rate for critic networks.
+        actor_lr: Learning rate for actor network.
+        temperature_lr: Learning rate for temperature parameter.
+        critic_target_update_weight: Weight for soft target updates.
+        utd_ratio: Update-to-data ratio (>1 to enable).
+        state_encoder_hidden_dim: Hidden dimension for state encoder.
+        latent_dim: Dimension of latent representation.
+        target_entropy: Target entropy for automatic temperature tuning.
+        use_backup_entropy: Whether to use backup entropy.
+        grad_clip_norm: Gradient clipping norm.
+        critic_network_kwargs: Additional arguments for critic networks.
+        actor_network_kwargs: Additional arguments for actor network.
+        policy_kwargs: Additional arguments for policy.
+    """
+
+    # Input / output structure
+    n_obs_steps: int = 1
+
+    normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
-            "observation.image": [3, 84, 84],
-            "observation.state": [4],
+            "VISUAL": NormalizationMode.MEAN_STD,
+            "STATE": NormalizationMode.MIN_MAX,
+            "ENV": NormalizationMode.MIN_MAX,
+            "ACTION": NormalizationMode.MIN_MAX,
         }
     )
-    output_shapes: dict[str, list[int]] = field(
-        default_factory=lambda: {
-            "action": [2],
-        }
-    )
-    input_normalization_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "observation.image": "mean_std",
-            "observation.state": "min_max",
-            "observation.environment_state": "min_max",
-        }
-    )
-    input_normalization_params: dict[str, dict[str, list[float]]] = field(
+
+    dataset_stats: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
             "observation.image": {
-                "mean": [[0.485, 0.456, 0.406]],
-                "std": [[0.229, 0.224, 0.225]],
+                "mean": [0.485, 0.456, 0.406],
+                "std": [0.229, 0.224, 0.225],
+            },
+            "observation.state": {
+                "min": [0.0, 0.0],
+                "max": [1.0, 1.0],
+            },
+            "action": {
+                "min": [0.0, 0.0, 0.0],
+                "max": [1.0, 1.0, 1.0],
             },
-            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
         }
     )
-    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
-    output_normalization_params: dict[str, dict[str, list[float]]] = field(
-        default_factory=lambda: {
-            "action": {"min": [-1, -1], "max": [1, 1]},
-        }
-    )
-    # TODO: Move it outside of the config
-    actor_learner_config: dict[str, str | int] = field(
-        default_factory=lambda: {
-            "learner_host": "127.0.0.1",
-            "learner_port": 50051,
-        }
-    )
-    camera_number: int = 1
 
+    # Architecture specifics
+    camera_number: int = 1
     storage_device: str = "cpu"
-    # Add type annotations for these fields:
-    vision_encoder_name: str | None = field(default="helper2424/resnet10")
+    # Set to "helper2424/resnet10" for HIL-SERL
+    vision_encoder_name: str | None = None
     freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
+
+    # SAC algorithm parameters
     discount: float = 0.99
     temperature_init: float = 1.0
     num_critics: int = 2
@@ -83,6 +108,8 @@ class SACConfig:
     target_entropy: float | None = None
     use_backup_entropy: bool = True
     grad_clip_norm: float = 40.0
+
+    # Network configuration
     critic_network_kwargs: dict[str, Any] = field(
         default_factory=lambda: {
             "hidden_dims": [256, 256],
@@ -104,3 +131,52 @@ class SACConfig:
             "init_final": 0.05,
         }
     )
+
+    # Deprecated, kept for backward compatibility
+    actor_learner_config: dict[str, str | int] = field(
+        default_factory=lambda: {
+            "learner_host": "127.0.0.1",
+            "learner_port": 50051,
+        }
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+        # Any validation specific to SAC configuration
+
+    def get_optimizer_preset(self) -> MultiAdamConfig:
+        return MultiAdamConfig(
+            weight_decay=0.0,
+            optimizer_groups={
+                "actor": {"lr": self.actor_lr},
+                "critic": {"lr": self.critic_lr},
+                "temperature": {"lr": self.temperature_lr},
+            },
+        )
+
+    def get_scheduler_preset(self) -> None:
+        return None
+
+    def validate_features(self) -> None:
+        # TODO: Maybe we should remove this raise?
+        if len(self.image_features) == 0:
+            raise ValueError("You must provide at least one image among the inputs.")
+
+    @property
+    def observation_delta_indices(self) -> list:
+        return list(range(1 - self.n_obs_steps, 1))
+
+    @property
+    def action_delta_indices(self) -> list:
+        return [0]  # SAC typically predicts one action at a time
+
+    @property
+    def reward_delta_indices(self) -> None:
+        return None
+
+if __name__ == "__main__":
+    import draccus
+    config = SACConfig()
+    draccus.set_config_type("json")
+    draccus.dump(config=config, stream=open("run_config.json", "w"))
+
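Note (reviewer sketch, not part of the diff): rough usage of the reworked config surface — registration makes the "sac" type discoverable through PreTrainedConfig, and get_optimizer_preset() maps the three learning rates onto a single MultiAdamConfig. The learning-rate values below are made up:

from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig(actor_lr=3e-4, critic_lr=3e-4, temperature_lr=3e-4)

# One optimizer config with three groups, matching SACPolicy.get_optim_params().
optim_cfg = config.get_optimizer_preset()
print(optim_cfg.optimizer_groups["actor"]["lr"])  # 0.0003

# Single-step observation window and one predicted action per step.
print(config.observation_delta_indices)  # [0] when n_obs_steps == 1
print(config.action_delta_indices)       # [0]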
config.output_features["action"].shape[0], **config.critic_network_kwargs, ) for _ in range(config.num_critics) @@ -125,12 +122,12 @@ class SACPolicy( self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), - action_dim=config.output_shapes["action"][0], + action_dim=config.output_features["action"].shape[0], encoder_is_shared=config.shared_encoder, **config.policy_kwargs, ) if config.target_entropy is None: - config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2) + config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2) # TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise @@ -140,102 +137,13 @@ class SACPolicy( self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)])) self.temperature = self.log_alpha.exp().item() - def _save_pretrained(self, save_directory): - """Custom save method to handle TensorDict properly""" - import json - import os - from dataclasses import asdict - from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE - from safetensors.torch import save_model - - save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)) - - # Save config - config_dict = asdict(self.config) - with open(os.path.join(save_directory, CONFIG_NAME), "w") as f: - json.dump(config_dict, f, indent=2) - print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}") - - @classmethod - def _from_pretrained( - cls, - *, - model_id: str, - revision: Optional[str], - cache_dir: Optional[Union[str, Path]], - force_download: bool, - proxies: Optional[Dict], - resume_download: Optional[bool], - local_files_only: bool, - token: Optional[Union[str, bool]], - map_location: str = "cpu", - strict: bool = False, - **model_kwargs, - ) -> "SACPolicy": - """Custom load method to handle loading SAC policy from saved files""" - import json - import os - from pathlib import Path - - from huggingface_hub import hf_hub_download - from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE - from safetensors.torch import load_model - - from lerobot.common.policies.sac.configuration_sac import SACConfig - - # Check if model_id is a local path or a hub model ID - if os.path.isdir(model_id): - model_path = Path(model_id) - safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE) - config_file = os.path.join(model_path, CONFIG_NAME) - else: - # Download the safetensors file from the hub - safetensors_file = hf_hub_download( - repo_id=model_id, - filename=SAFETENSORS_SINGLE_FILE, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - token=token, - local_files_only=local_files_only, - ) - # Download the config file - try: - config_file = hf_hub_download( - repo_id=model_id, - filename=CONFIG_NAME, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - token=token, - local_files_only=local_files_only, - ) - except Exception: - config_file = None - - # Load or create config - if config_file and os.path.exists(config_file): - # Load config from file - with open(config_file) as f: - config_dict = json.load(f) - config = SACConfig(**config_dict) - else: - # Use the provided config or create a default one - config = model_kwargs.get("config", SACConfig()) - - # Create a 
@@ -140,102 +137,13 @@ class SACPolicy(
         self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()
 
-    def _save_pretrained(self, save_directory):
-        """Custom save method to handle TensorDict properly"""
-        import json
-        import os
-        from dataclasses import asdict
-
-        from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
-        from safetensors.torch import save_model
-
-        save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
-
-        # Save config
-        config_dict = asdict(self.config)
-        with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
-            json.dump(config_dict, f, indent=2)
-
-        print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
-
-    @classmethod
-    def _from_pretrained(
-        cls,
-        *,
-        model_id: str,
-        revision: Optional[str],
-        cache_dir: Optional[Union[str, Path]],
-        force_download: bool,
-        proxies: Optional[Dict],
-        resume_download: Optional[bool],
-        local_files_only: bool,
-        token: Optional[Union[str, bool]],
-        map_location: str = "cpu",
-        strict: bool = False,
-        **model_kwargs,
-    ) -> "SACPolicy":
-        """Custom load method to handle loading SAC policy from saved files"""
-        import json
-        import os
-        from pathlib import Path
-
-        from huggingface_hub import hf_hub_download
-        from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
-        from safetensors.torch import load_model
-
-        from lerobot.common.policies.sac.configuration_sac import SACConfig
-
-        # Check if model_id is a local path or a hub model ID
-        if os.path.isdir(model_id):
-            model_path = Path(model_id)
-            safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
-            config_file = os.path.join(model_path, CONFIG_NAME)
-        else:
-            # Download the safetensors file from the hub
-            safetensors_file = hf_hub_download(
-                repo_id=model_id,
-                filename=SAFETENSORS_SINGLE_FILE,
-                revision=revision,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                token=token,
-                local_files_only=local_files_only,
-            )
-            # Download the config file
-            try:
-                config_file = hf_hub_download(
-                    repo_id=model_id,
-                    filename=CONFIG_NAME,
-                    revision=revision,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    token=token,
-                    local_files_only=local_files_only,
-                )
-            except Exception:
-                config_file = None
-
-        # Load or create config
-        if config_file and os.path.exists(config_file):
-            # Load config from file
-            with open(config_file) as f:
-                config_dict = json.load(f)
-            config = SACConfig(**config_dict)
-        else:
-            # Use the provided config or create a default one
-            config = model_kwargs.get("config", SACConfig())
-
-        # Create a new instance with the loaded config
-        model = cls(config=config)
-
-        # Load state dict from safetensors file
-        if os.path.exists(safetensors_file):
-            load_model(model, filename=safetensors_file, device=map_location)
-
-        return model
+    def get_optim_params(self) -> dict:
+        return {
+            "actor": self.actor.parameters_to_optimize,
+            "critic": self.critic_ensemble.parameters_to_optimize,
+            "temperature": self.log_alpha,
+        }
 
     def reset(self):
         """Reset the policy"""
@@ -667,7 +575,7 @@ class SACObservationEncoder(nn.Module):
         self.parameters_to_optimize = []
 
         self.aggregation_size: int = 0
-        if any("observation.image" in key for key in config.input_shapes):
+        if any("observation.image" in key for key in config.input_features):
             self.camera_number = config.camera_number
 
             if self.config.vision_encoder_name is not None:
@@ -682,12 +590,12 @@ class SACObservationEncoder(nn.Module):
                 freeze_image_encoder(self.image_enc_layers)
             else:
                 self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-            self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+            self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
 
-        if "observation.state" in config.input_shapes:
+        if "observation.state" in config.input_features:
             self.state_enc_layers = nn.Sequential(
                 nn.Linear(
-                    in_features=config.input_shapes["observation.state"][0],
+                    in_features=config.input_features["observation.state"].shape[0],
                     out_features=config.latent_dim,
                 ),
                 nn.LayerNorm(normalized_shape=config.latent_dim),
@@ -697,10 +605,10 @@ class SACObservationEncoder(nn.Module):
 
             self.parameters_to_optimize += list(self.state_enc_layers.parameters())
 
-        if "observation.environment_state" in config.input_shapes:
+        if "observation.environment_state" in config.input_features:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
-                    in_features=config.input_shapes["observation.environment_state"][0],
+                    in_features=config.input_features["observation.environment_state"].shape[0],
                     out_features=config.latent_dim,
                 ),
                 nn.LayerNorm(normalized_shape=config.latent_dim),
@@ -727,9 +635,9 @@ class SACObservationEncoder(nn.Module):
             embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
             feat.extend(embeddings_chunks)
 
-        if "observation.environment_state" in self.config.input_shapes:
+        if "observation.environment_state" in self.config.input_features:
             feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
-        if "observation.state" in self.config.input_shapes:
+        if "observation.state" in self.config.input_features:
             feat.append(self.state_enc_layers(obs_dict["observation.state"]))
 
         features = torch.cat(tensors=feat, dim=-1)
@@ -744,11 +652,11 @@ class SACObservationEncoder(nn.Module):
 
 
 class DefaultImageEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: SACConfig):
         super().__init__()
         self.image_enc_layers = nn.Sequential(
             nn.Conv2d(
-                in_channels=config.input_shapes["observation.image"][0],
+                in_channels=config.input_features["observation.image"].shape[0],
                 out_channels=config.image_encoder_hidden_dim,
                 kernel_size=7,
                 stride=2,
@@ -776,7 +684,7 @@ class DefaultImageEncoder(nn.Module):
             ),
             nn.ReLU(),
         )
-        dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
+        dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
         with torch.inference_mode():
             self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
         self.image_enc_layers.extend(
@@ -793,7 +701,7 @@ class DefaultImageEncoder(nn.Module):
 
 
 class PretrainedImageEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: SACConfig):
         super().__init__()
 
         self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
@@ -803,7 +711,7 @@ class PretrainedImageEncoder(nn.Module):
             nn.Tanh(),
         )
 
-    def _load_pretrained_vision_encoder(self, config):
+    def _load_pretrained_vision_encoder(self, config: SACConfig):
         """Set up CNN encoder"""
         from transformers import AutoModel
 
@@ -857,73 +765,73 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
 
 
 if __name__ == "__main__":
-    # Benchmark the CriticEnsemble performance
-    import time
+    # # Benchmark the CriticEnsemble performance
+    # import time
 
-    # Configuration
-    num_critics = 10
-    batch_size = 32
-    action_dim = 7
-    obs_dim = 64
-    hidden_dims = [256, 256]
-    num_iterations = 100
+    # # Configuration
+    # num_critics = 10
+    # batch_size = 32
+    # action_dim = 7
+    # obs_dim = 64
+    # hidden_dims = [256, 256]
+    # num_iterations = 100
 
-    print("Creating test environment...")
+    # print("Creating test environment...")
 
-    # Create a simple dummy encoder
-    class DummyEncoder(nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.output_dim = obs_dim
-            self.parameters_to_optimize = []
+    # # Create a simple dummy encoder
+    # class DummyEncoder(nn.Module):
+    #     def __init__(self):
+    #         super().__init__()
+    #         self.output_dim = obs_dim
+    #         self.parameters_to_optimize = []
 
-        def forward(self, obs):
-            # Just return a random tensor of the right shape
-            # In practice, this would encode the observations
-            return torch.randn(batch_size, obs_dim, device=device)
+    #     def forward(self, obs):
+    #         # Just return a random tensor of the right shape
+    #         # In practice, this would encode the observations
+    #         return torch.randn(batch_size, obs_dim, device=device)
 
-    # Create critic heads
-    print(f"Creating {num_critics} critic heads...")
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    critic_heads = [
-        CriticHead(
-            input_dim=obs_dim + action_dim,
-            hidden_dims=hidden_dims,
-        ).to(device)
-        for _ in range(num_critics)
-    ]
+    # # Create critic heads
+    # print(f"Creating {num_critics} critic heads...")
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # critic_heads = [
+    #     CriticHead(
+    #         input_dim=obs_dim + action_dim,
+    #         hidden_dims=hidden_dims,
+    #     ).to(device)
+    #     for _ in range(num_critics)
+    # ]
 
-    # Create the critic ensemble
-    print("Creating CriticEnsemble...")
-    critic_ensemble = CriticEnsemble(
-        encoder=DummyEncoder().to(device),
-        ensemble=critic_heads,
-        output_normalization=nn.Identity(),
-    ).to(device)
+    # # Create the critic ensemble
+    # print("Creating CriticEnsemble...")
+    # critic_ensemble = CriticEnsemble(
+    #     encoder=DummyEncoder().to(device),
+    #     ensemble=critic_heads,
+    #     output_normalization=nn.Identity(),
+    # ).to(device)
 
-    # Create random input data
-    print("Creating input data...")
-    obs_dict = {
-        "observation.state": torch.randn(batch_size, obs_dim, device=device),
-    }
-    actions = torch.randn(batch_size, action_dim, device=device)
+    # # Create random input data
+    # print("Creating input data...")
+    # obs_dict = {
+    #     "observation.state": torch.randn(batch_size, obs_dim, device=device),
+    # }
+    # actions = torch.randn(batch_size, action_dim, device=device)
 
-    # Warmup run
-    print("Warming up...")
-    _ = critic_ensemble(obs_dict, actions)
+    # # Warmup run
+    # print("Warming up...")
+    # _ = critic_ensemble(obs_dict, actions)
 
-    # Time the forward pass
-    print(f"Running benchmark with {num_iterations} iterations...")
-    start_time = time.perf_counter()
-    for _ in range(num_iterations):
-        q_values = critic_ensemble(obs_dict, actions)
-    end_time = time.perf_counter()
+    # # Time the forward pass
+    # print(f"Running benchmark with {num_iterations} iterations...")
+    # start_time = time.perf_counter()
+    # for _ in range(num_iterations):
+    #     q_values = critic_ensemble(obs_dict, actions)
+    # end_time = time.perf_counter()
 
-    # Print results
-    elapsed_time = end_time - start_time
-    print(f"Total time: {elapsed_time:.4f} seconds")
-    print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
-    print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
+    # # Print results
+    # elapsed_time = end_time - start_time
+    # print(f"Total time: {elapsed_time:.4f} seconds")
+    # print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
+    # print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
 
     # Verify that all critic heads produce different outputs
     # This confirms each critic head is unique
@@ -932,3 +840,11 @@ if __name__ == "__main__":
     #     for j in range(i + 1, num_critics):
     #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
     #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
+    import draccus
+
+    from lerobot.configs import parser
+    @parser.wrap()
+    def main(config: SACConfig):
+        policy = SACPolicy(config=config)
+        print("yolo")
+    main()
\ No newline at end of file
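Note (reviewer sketch, not part of the diff): how get_optim_params() is expected to line up with the MultiAdamConfig preset from configuration_sac.py. The build_optimizers helper below is illustrative only — it is not lerobot's actual training-loop wiring, which goes through the optimizer preset instead:

import torch

def build_optimizers(policy, config):
    # One Adam per group; group names mirror MultiAdamConfig.optimizer_groups
    # and SACPolicy.get_optim_params(): "actor", "critic", "temperature".
    params = policy.get_optim_params()
    lrs = {"actor": config.actor_lr, "critic": config.critic_lr, "temperature": config.temperature_lr}
    return {
        name: torch.optim.Adam(
            group if isinstance(group, (list, tuple)) else [group],  # log_alpha is a single nn.Parameter
            lr=lrs[name],
        )
        for name, group in params.items()
    }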