From 87da655eab47df09e821a5a097ef60057a6d286c Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 21 Jan 2025 09:51:12 +0000 Subject: [PATCH] Add type annotations and restructure SACConfig class fields --- .../common/policies/sac/configuration_sac.py | 94 ++++++++++++------- 1 file changed, 62 insertions(+), 32 deletions(-) diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 3f5dae1c..97ba04b1 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -16,6 +16,7 @@ # limitations under the License. from dataclasses import dataclass, field +from typing import Any @dataclass @@ -26,6 +27,7 @@ class SACConfig: "observation.state": [4], } ) + output_shapes: dict[str, list[int]] = field( default_factory=lambda: { "action": [2], @@ -43,36 +45,64 @@ class SACConfig: output_normalization_modes: dict[str, str] = field( default_factory=lambda: {"action": "min_max"}, ) +from dataclasses import dataclass, field - shared_encoder = False - discount = 0.99 - temperature_init = 1.0 - num_critics = 2 - # num_critics = 8 - num_subsample_critics = None - # num_subsample_critics = 2 - # critic_lr = 1e-3 - critic_lr = 3e-4 - actor_lr = 3e-4 - temperature_lr = 3e-4 - critic_target_update_weight = 0.005 - # utd_ratio = 8 - utd_ratio = 1 # If you want enable utd_ratio, you need to set it to >1 - state_encoder_hidden_dim = 256 - latent_dim = 256 - target_entropy = None - # backup_entropy = False - use_backup_entropy = True - critic_network_kwargs = { - "hidden_dims": [256, 256], - "activate_final": True, - } - actor_network_kwargs = { - "hidden_dims": [256, 256], - "activate_final": True, - } - policy_kwargs = { - "use_tanh_squash": True, - "log_std_min": -5, - "log_std_max": 2, - } +@dataclass +class SACConfig: + input_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "observation.image": [3, 84, 84], + "observation.state": [4], + } + ) + output_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "action": [2], + } + ) + input_normalization_modes: dict[str, str] = field( + default_factory=lambda: { + "observation.image": "mean_std", + "observation.state": "min_max", + "observation.environment_state": "min_max", + } + ) + output_normalization_modes: dict[str, str] = field( + default_factory=lambda: {"action": "min_max"} + ) + + # Add type annotations for these fields: + image_encoder_hidden_dim: int = 32 + shared_encoder: bool = False + discount: float = 0.99 + temperature_init: float = 1.0 + num_critics: int = 2 + num_subsample_critics: int | None = None + critic_lr: float = 3e-4 + actor_lr: float = 3e-4 + temperature_lr: float = 3e-4 + critic_target_update_weight: float = 0.005 + utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1 + state_encoder_hidden_dim: int = 256 + latent_dim: int = 256 + target_entropy: float | None = None + use_backup_entropy: bool = True + critic_network_kwargs: dict[str, Any] = field( + default_factory=lambda: { + "hidden_dims": [256, 256], + "activate_final": True, + } + ) + actor_network_kwargs: dict[str, Any] = field( + default_factory=lambda: { + "hidden_dims": [256, 256], + "activate_final": True, + } + ) + policy_kwargs: dict[str, Any] = field( + default_factory=lambda: { + "use_tanh_squash": True, + "log_std_min": -5, + "log_std_max": 2, + } + )