diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py
index 0cf4124c..bfccd19a 100644
--- a/lerobot/common/optim/optimizers.py
+++ b/lerobot/common/optim/optimizers.py
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from pathlib import Path
+from typing import Any
 
 import draccus
 import torch
@@ -44,7 +45,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
         return "adam"
 
     @abc.abstractmethod
-    def build(self) -> torch.optim.Optimizer:
+    def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
         raise NotImplementedError
 
 
@@ -94,7 +95,73 @@ class SGDConfig(OptimizerConfig):
         return torch.optim.SGD(params, **kwargs)
 
 
-def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
+@OptimizerConfig.register_subclass("multi_adam")
+@dataclass
+class MultiAdamConfig(OptimizerConfig):
+    """Configuration for multiple Adam optimizers with different parameter groups.
+
+    This creates a dictionary of Adam optimizers, each with its own hyperparameters.
+
+    Args:
+        lr: Default learning rate (used if not specified for a group)
+        weight_decay: Default weight decay (used if not specified for a group)
+        optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
+        grad_clip_norm: Gradient clipping norm
+    """
+    lr: float = 1e-3
+    weight_decay: float = 0.0
+    grad_clip_norm: float = 10.0
+    optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
+
+    def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
+        """Build multiple Adam optimizers.
+
+        Args:
+            params_dict: Dictionary mapping parameter group names to lists of parameters.
+                The keys should match the keys in optimizer_groups.
+
+        Returns:
+            Dictionary mapping parameter group names to their optimizers.
+        """
+        optimizers = {}
+
+        for name, params in params_dict.items():
+            # Get group-specific hyperparameters or use defaults
+            group_config = self.optimizer_groups.get(name, {})
+
+            # Create optimizer with merged parameters (defaults + group-specific)
+            optimizer_kwargs = {
+                "lr": group_config.get("lr", self.lr),
+                "betas": group_config.get("betas", (0.9, 0.999)),
+                "eps": group_config.get("eps", 1e-5),
+                "weight_decay": group_config.get("weight_decay", self.weight_decay),
+            }
+
+            optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
+
+        return optimizers
+
+
+def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
+    """Save optimizer state to disk.
+
+    Args:
+        optimizer: Either a single optimizer or a dictionary of optimizers.
+        save_dir: Directory to save the optimizer state.
+    """
+    if isinstance(optimizer, dict):
+        # Handle dictionary of optimizers
+        for name, opt in optimizer.items():
+            optimizer_dir = save_dir / name
+            optimizer_dir.mkdir(exist_ok=True, parents=True)
+            _save_single_optimizer_state(opt, optimizer_dir)
+    else:
+        # Handle single optimizer
+        _save_single_optimizer_state(optimizer, save_dir)
+
+
+def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
+    """Save a single optimizer's state to disk."""
     state = optimizer.state_dict()
     param_groups = state.pop("param_groups")
     flat_state = flatten_dict(state)
@@ -102,11 +169,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
     write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
 
 
-def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
+def load_optimizer_state(
+    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
+) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
+    """Load optimizer state from disk.
+
+    Args:
+        optimizer: Either a single optimizer or a dictionary of optimizers.
+        save_dir: Directory to load the optimizer state from.
+
+    Returns:
+        The updated optimizer(s) with loaded state.
+    """
+    if isinstance(optimizer, dict):
+        # Handle dictionary of optimizers
+        loaded_optimizers = {}
+        for name, opt in optimizer.items():
+            optimizer_dir = save_dir / name
+            if optimizer_dir.exists():
+                loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
+            else:
+                loaded_optimizers[name] = opt
+        return loaded_optimizers
+    else:
+        # Handle single optimizer
+        return _load_single_optimizer_state(optimizer, save_dir)
+
+
+def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
+    """Load a single optimizer's state from disk."""
     current_state_dict = optimizer.state_dict()
     flat_state = load_file(save_dir / OPTIMIZER_STATE)
     state = unflatten_dict(flat_state)
-    loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
+
+    # Handle case where 'state' key might not exist (for newly created optimizers)
+    if "state" in state:
+        loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
+    else:
+        loaded_state_dict = {"state": {}}
 
     if "param_groups" in current_state_dict:
         param_groups = deserialize_json_into_object(
diff --git a/tests/optim/test_optimizers.py b/tests/optim/test_optimizers.py
index 997e14fe..8f431ce5 100644
--- a/tests/optim/test_optimizers.py
+++ b/tests/optim/test_optimizers.py
@@ -21,6 +21,7 @@ from lerobot.common.constants import (
 from lerobot.common.optim.optimizers import (
     AdamConfig,
     AdamWConfig,
+    MultiAdamConfig,
     SGDConfig,
     load_optimizer_state,
     save_optimizer_state,
@@ -33,13 +34,21 @@ from lerobot.common.optim.optimizers import (
         (AdamConfig, torch.optim.Adam),
         (AdamWConfig, torch.optim.AdamW),
         (SGDConfig, torch.optim.SGD),
+        (MultiAdamConfig, dict),
     ],
 )
 def test_optimizer_build(config_cls, expected_class, model_params):
     config = config_cls()
-    optimizer = config.build(model_params)
-    assert isinstance(optimizer, expected_class)
-    assert optimizer.defaults["lr"] == config.lr
+    if config_cls == MultiAdamConfig:
+        params_dict = {"default": model_params}
+        optimizer = config.build(params_dict)
+        assert isinstance(optimizer, expected_class)
+        assert isinstance(optimizer["default"], torch.optim.Adam)
+        assert optimizer["default"].defaults["lr"] == config.lr
+    else:
+        optimizer = config.build(model_params)
+        assert isinstance(optimizer, expected_class)
+        assert optimizer.defaults["lr"] == config.lr
 
 
 def test_save_optimizer_state(optimizer, tmp_path):
@@ -54,3 +63,185 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
     loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
 
     torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
+
+
+@pytest.fixture
+def base_params_dict():
+    return {
+        "actor": [torch.nn.Parameter(torch.randn(10, 10))],
+        "critic": [torch.nn.Parameter(torch.randn(5, 5))],
+        "temperature": [torch.nn.Parameter(torch.randn(3, 3))],
+    }
+
+
+@pytest.mark.parametrize(
+    "config_params, expected_values",
+    [
+        # Test 1: Basic configuration with different learning rates
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4},
+                    "critic": {"lr": 5e-4},
+                    "temperature": {"lr": 2e-3},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+            },
+        ),
+        # Test 2: Different weight decays and beta values
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4, "weight_decay": 1e-5},
+                    "critic": {"lr": 5e-4, "weight_decay": 1e-6},
+                    "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
+            },
+        ),
+        # Test 3: Epsilon parameter customization
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4, "eps": 1e-6},
+                    "critic": {"lr": 5e-4, "eps": 1e-7},
+                    "temperature": {"lr": 2e-3, "eps": 1e-8},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
+            },
+        ),
+    ],
+)
+def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
+    # Create config with the given parameters
+    config = MultiAdamConfig(**config_params)
+    optimizers = config.build(base_params_dict)
+
+    # Verify optimizer count and keys
+    assert len(optimizers) == len(expected_values)
+    assert set(optimizers.keys()) == set(expected_values.keys())
+
+    # Check that all optimizers are Adam instances
+    for opt in optimizers.values():
+        assert isinstance(opt, torch.optim.Adam)
+
+    # Verify hyperparameters for each optimizer
+    for name, expected in expected_values.items():
+        optimizer = optimizers[name]
+        for param, value in expected.items():
+            assert optimizer.defaults[param] == value
+
+
+@pytest.fixture
+def multi_optimizers(base_params_dict):
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    return config.build(base_params_dict)
+
+
+def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
+    # Save optimizer states
+    save_optimizer_state(multi_optimizers, tmp_path)
+
+    # Verify that directories were created for each optimizer
+    for name in multi_optimizers.keys():
+        assert (tmp_path / name).is_dir()
+        assert (tmp_path / name / OPTIMIZER_STATE).is_file()
+        assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
+
+
+def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
+    # Run a minimal backward pass and optimizer step to populate optimizer state
+    for name, params in base_params_dict.items():
+        if name in multi_optimizers:
+            # Create a dummy loss and do backward
+            dummy_loss = params[0].sum()
+            dummy_loss.backward()
+            # Perform an optimization step
+            multi_optimizers[name].step()
+            # Zero gradients for next steps
+            multi_optimizers[name].zero_grad()
+
+    # Save optimizer states
+    save_optimizer_state(multi_optimizers, tmp_path)
+
+    # Create new optimizers with the same config
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    new_optimizers = config.build(base_params_dict)
+
+    # Load optimizer states
+    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
+
+    # Verify state dictionaries match
+    for name in multi_optimizers.keys():
+        torch.testing.assert_close(
+            multi_optimizers[name].state_dict(),
+            loaded_optimizers[name].state_dict()
+        )
+
+
+def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
+    """Test saving and loading optimizer states even when the state is empty (no backward pass)."""
+    # Create config and build optimizers
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    optimizers = config.build(base_params_dict)
+
+    # Save optimizer states without any backward pass (empty state)
+    save_optimizer_state(optimizers, tmp_path)
+
+    # Create new optimizers with the same config
+    new_optimizers = config.build(base_params_dict)
+
+    # Load optimizer states
+    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
+
+    # Verify hyperparameters match even with empty state
+    for name, optimizer in optimizers.items():
+        assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
+        assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
+        assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
+
+        # Verify state dictionaries match (they will be empty)
+        torch.testing.assert_close(
+            optimizer.state_dict()["param_groups"],
+            loaded_optimizers[name].state_dict()["param_groups"]
+        )
+
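Not part of the patch above: a minimal usage sketch of the `MultiAdamConfig` API introduced by this diff, assuming the changes are applied. The `actor`/`critic` modules, the hyperparameter overrides, and the checkpoint directory are illustrative placeholders, not values taken from the PR.

```python
from pathlib import Path

import torch

from lerobot.common.optim.optimizers import (
    MultiAdamConfig,
    load_optimizer_state,
    save_optimizer_state,
)

# Two illustrative parameter groups; in practice these would come from a policy's submodules.
actor = torch.nn.Linear(4, 2)
critic = torch.nn.Linear(4, 1)

config = MultiAdamConfig(
    lr=1e-3,            # default lr for groups without an override
    weight_decay=1e-4,  # default weight decay
    optimizer_groups={
        "actor": {"lr": 1e-4},
        "critic": {"lr": 5e-4, "betas": (0.95, 0.999)},
    },
)

# build() takes a dict of parameter lists keyed like optimizer_groups and
# returns a dict of torch.optim.Adam instances, one per group.
params_dict = {
    "actor": list(actor.parameters()),
    "critic": list(critic.parameters()),
}
optimizers = config.build(params_dict)

# One training step per group to populate optimizer state before checkpointing.
loss = actor(torch.randn(8, 4)).sum() + critic(torch.randn(8, 4)).sum()
loss.backward()
for opt in optimizers.values():
    opt.step()
    opt.zero_grad()

# Each optimizer is saved to, and loaded from, its own subdirectory of the checkpoint dir.
ckpt_dir = Path("/tmp/multi_adam_ckpt")  # hypothetical location
save_optimizer_state(optimizers, ckpt_dir)

fresh = config.build(params_dict)
restored = load_optimizer_state(fresh, ckpt_dir)
```

Because `save_optimizer_state` and `load_optimizer_state` now accept either a single optimizer or a dict, existing single-optimizer call sites keep working unchanged; the dict path simply namespaces each optimizer's state under `save_dir / name`.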