diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index f4a2bc4c..6df94761 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -20,6 +20,24 @@ from dataclasses import dataclass, field
 
 @dataclass
 class SACConfig:
+    input_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "observation.image": [3, 84, 84],
+            "observation.state": [4],
+        }
+    )
+    output_shapes: dict[str, list[int]] = field(
+        default_factory=lambda: {
+            "action": [4],
+        }
+    )
+
+    # Normalization / Unnormalization
+    input_normalization_modes: dict[str, str] | None = None
+    output_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {"action": "min_max"},
+    )
+
     discount = 0.99
     temperature_init = 1.0
     num_critics = 2
@@ -29,6 +47,9 @@ class SACConfig:
     temperature_lr = 3e-4
     critic_target_update_weight = 0.005
     utd_ratio = 2
+    state_encoder_hidden_dim = 256
+    latent_dim = 50
+    target_entropy = None
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 51258fac..87170d20 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -40,6 +40,8 @@ class SACPolicy(
     repo_url="https://github.com/huggingface/lerobot",
     tags=["robotics", "RL", "SAC"],
 ):
+    name = "sac"
+
     def __init__(
         self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
     ):
@@ -71,7 +73,7 @@ class SACPolicy(
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
         self.critic_target = deepcopy(self.critic_ensemble)
 
-        self.actor_network = Policy(
+        self.actor = Policy(
             encoder=encoder,
             network=MLP(**config.actor_network_kwargs),
             action_dim=config.output_shapes["action"][0],
@@ -91,14 +93,14 @@ class SACPolicy(
             "observation.state": deque(maxlen=1),
             "action": deque(maxlen=1),
         }
-        if self._use_image:
+        if "observation.image" in self.config.input_shapes:
             self._queues["observation.image"] = deque(maxlen=1)
-        if self._use_env_state:
+        if "observation.environment_state" in self.config.input_shapes:
             self._queues["observation.environment_state"] = deque(maxlen=1)
 
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        actions, _ = self.actor_network(batch["observations"])  ###
+        actions, _ = self.actor(batch["observations"])
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
@@ -119,19 +121,18 @@ class SACPolicy(
 
         # perform image augmentation
 
-        # reward bias
-        # from HIL-SERL code base
+        # reward bias from HIL-SERL code base
         # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
 
         # calculate critics loss
         # 1- compute actions from policy
-        action_preds, log_probs = self.actor_network(observations)
+        action_preds, log_probs = self.actor(observations)
         # 2- compute q targets
         q_targets = self.target_qs(next_observations, action_preds)
 
         # subsample critics to prevent overfitting if use high UTD (update to date)
         if self.config.num_subsample_critics is not None:
             indices = torch.randperm(self.config.num_critics)
-            indices = indices[: self.config.num_subsample_critics]
+            indices = indices[:self.config.num_subsample_critics]
             q_targets = q_targets[indices]  # critics subsample size
 
@@ -168,7 +169,8 @@ class SACPolicy(
         temperature = self.temperature()
 
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor_network(observations)
+        actions, log_probs = self.actor(observations)
+
         # 3- get q-value predictions
         with torch.no_grad():
             q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -209,21 +211,19 @@
 class MLP(nn.Module):
     def __init__(
         self,
-        config: SACConfig,
+        hidden_dims: list[int],
         activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
         activate_final: bool = False,
         dropout_rate: Optional[float] = None,
     ):
         super().__init__()
-        self.activate_final = config.activate_final
+        self.activate_final = activate_final
         layers = []
-
-        for i, size in enumerate(config.network_hidden_dims):
-            layers.append(
-                nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
-            )
-
-            if i + 1 < len(config.network_hidden_dims) or activate_final:
+
+        for i, size in enumerate(hidden_dims):
+            layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
+
+            if i + 1 < len(hidden_dims) or activate_final:
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
                 layers.append(nn.LayerNorm(size))
@@ -254,20 +254,20 @@ class Critic(nn.Module):
         self.network = network
         self.init_final = init_final
         self.activate_final = activate_final
-
+
+        # Find the last Linear layer's output dimension
+        for layer in reversed(network.net):
+            if isinstance(layer, nn.Linear):
+                out_features = layer.out_features
+                break
+
         # Output layer
         if init_final is not None:
-            if self.activate_final:
-                self.output_layer = nn.Linear(network.net[-3].out_features, 1)
-            else:
-                self.output_layer = nn.Linear(network.net[-2].out_features, 1)
+            self.output_layer = nn.Linear(out_features, 1)
             nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
         else:
-            if self.activate_final:
-                self.output_layer = nn.Linear(network.net[-3].out_features, 1)
-            else:
-                self.output_layer = nn.Linear(network.net[-2].out_features, 1)
+            self.output_layer = nn.Linear(out_features, 1)
             orthogonal_init()(self.output_layer.weight)
 
         self.to(self.device)
@@ -328,12 +328,15 @@ class Policy(nn.Module):
         self.tanh_squash_distribution = tanh_squash_distribution
         self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
         self.activate_final = activate_final
-
+
+        # Find the last Linear layer's output dimension
+        for layer in reversed(network.net):
+            if isinstance(layer, nn.Linear):
+                out_features = layer.out_features
+                break
+
         # Mean layer
-        if self.activate_final:
-            self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
-        else:
-            self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
+        self.mean_layer = nn.Linear(out_features, action_dim)
         if init_final is not None:
             nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
@@ -345,10 +348,7 @@ class Policy(nn.Module):
         if std_parameterization == "uniform":
             self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
         else:
-            if self.activate_final:
-                self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
-            else:
-                self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
+            self.std_layer = nn.Linear(out_features, action_dim)
             if init_final is not None:
                 nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
                 nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
@@ -571,7 +571,6 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         """Get the mode of the transformed distribution"""
         # The mode of a normal distribution is its mean
         mode = self.loc
-
         # Apply transforms
         for transform in self.transforms:
             mode = transform(mode)
@@ -634,10 +633,10 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
         return entropy
 
 
-def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
+def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
     """Creates an ensemble of critic networks"""
-    critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
-    return critics.to(device)
+    assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
+    return nn.ModuleList(critics).to(device)
 
 
 def orthogonal_init():
diff --git a/lerobot/configs/policy/sac_pusht_keypoints.yaml b/lerobot/configs/policy/sac_pusht_keypoints.yaml
new file mode 100644
index 00000000..19af60d4
--- /dev/null
+++ b/lerobot/configs/policy/sac_pusht_keypoints.yaml
@@ -0,0 +1,89 @@
+# @package _global_
+
+# Train with:
+#
+# python lerobot/scripts/train.py \
+#   env=pusht \
+#   +dataset=lerobot/pusht_keypoints
+
+seed: 1
+dataset_repo_id: lerobot/pusht_keypoints
+
+training:
+  offline_steps: 0
+
+  # Offline training dataloader
+  num_workers: 4
+
+  batch_size: 128
+  grad_clip_norm: 10.0
+  lr: 3e-4
+
+  eval_freq: 10000
+  log_freq: 500
+  save_freq: 50000
+
+  online_steps: 1000000
+  online_rollout_n_episodes: 10
+  online_rollout_batch_size: 10
+  online_steps_between_rollouts: 1000
+  online_sampling_ratio: 1.0
+  online_env_seed: 10000
+  online_buffer_capacity: 40000
+  online_buffer_seed_size: 0
+  do_online_rollout_async: false
+
+  delta_timestamps:
+    observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+    observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
+    action: "[i / ${fps} for i in range(${policy.horizon})]"
+    next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
+
+policy:
+  name: sac
+
+  pretrained_model_path:
+
+  # Input / output structure.
+  n_action_repeats: 1
+  horizon: 5
+  n_action_steps: 5
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.environment_state: [16]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.environment_state: min_max
+    observation.state: min_max
+  output_normalization_modes:
+    action: min_max
+
+  # Architecture / modeling.
+  # Neural networks.
+  # image_encoder_hidden_dim: 32
+  discount: 0.99
+  temperature_init: 1.0
+  num_critics: 2
+  num_subsample_critics: null
+  critic_lr: 3e-4
+  actor_lr: 3e-4
+  temperature_lr: 3e-4
+  critic_target_update_weight: 0.005
+  utd_ratio: 2
+
+
+  # # Loss coefficients.
+  # reward_coeff: 0.5
+  # expectile_weight: 0.9
+  # value_coeff: 0.1
+  # consistency_coeff: 20.0
+  # advantage_scaling: 3.0
+  # pi_coeff: 0.5
+  # temporal_decay_coeff: 0.5
+  # # Target model.
+  # target_model_momentum: 0.995
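
Below is a minimal, self-contained sketch of the new create_critic_ensemble() contract introduced by this patch: the caller now builds the critic modules itself (in SACPolicy.__init__ they come from MLP(**config.critic_network_kwargs) wrapped in Critic) and passes the list in, instead of handing over a class to be instantiated num_critics times. The SimpleQ module, the dimensions, and the "cpu" device below are made-up stand-ins for illustration and are not code from the repository (the patch defaults to "cuda").

import torch
from torch import nn


def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
    """Same signature as in the patch: validate the list and move it to the target device."""
    assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
    return nn.ModuleList(critics).to(device)


class SimpleQ(nn.Module):
    """Hypothetical stand-in for the repo's Critic: maps (obs, action) pairs to a scalar Q-value."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, observations: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([observations, actions], dim=-1))


# Build the critics first, then hand the list to the ensemble helper (num_critics=2, as in SACConfig).
critic_nets = [SimpleQ(obs_dim=4, action_dim=4) for _ in range(2)]
critic_ensemble = create_critic_ensemble(critic_nets, num_critics=2, device="cpu")

obs, acts = torch.randn(8, 4), torch.randn(8, 4)
q_values = torch.stack([q(obs, acts) for q in critic_ensemble])  # shape: (2, 8, 1)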