Handle multi optimizers

AdilZouitine 2025-03-24 15:34:30 +00:00
parent b4ec6c8afb
commit 425e604f76
2 changed files with 299 additions and 8 deletions

lerobot/common/optim/optimizers.py

@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from pathlib import Path
+from typing import Any
 
 import draccus
 import torch
@@ -44,7 +45,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
         return "adam"
 
     @abc.abstractmethod
-    def build(self) -> torch.optim.Optimizer:
+    def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
         raise NotImplementedError
@@ -94,7 +95,73 @@ class SGDConfig(OptimizerConfig):
         return torch.optim.SGD(params, **kwargs)
 
 
-def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
+@OptimizerConfig.register_subclass("multi_adam")
+@dataclass
+class MultiAdamConfig(OptimizerConfig):
+    """Configuration for multiple Adam optimizers with different parameter groups.
+
+    This creates a dictionary of Adam optimizers, each with its own hyperparameters.
+
+    Args:
+        lr: Default learning rate (used if not specified for a group)
+        weight_decay: Default weight decay (used if not specified for a group)
+        optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
+        grad_clip_norm: Gradient clipping norm
+    """
+
+    lr: float = 1e-3
+    weight_decay: float = 0.0
+    grad_clip_norm: float = 10.0
+    optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
+
+    def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
+        """Build multiple Adam optimizers.
+
+        Args:
+            params_dict: Dictionary mapping parameter group names to lists of parameters.
+                The keys should match the keys in optimizer_groups.
+
+        Returns:
+            Dictionary mapping parameter group names to their optimizers.
+        """
+        optimizers = {}
+        for name, params in params_dict.items():
+            # Get group-specific hyperparameters or use defaults
+            group_config = self.optimizer_groups.get(name, {})
+
+            # Create optimizer with merged parameters (defaults + group-specific)
+            optimizer_kwargs = {
+                "lr": group_config.get("lr", self.lr),
+                "betas": group_config.get("betas", (0.9, 0.999)),
+                "eps": group_config.get("eps", 1e-5),
+                "weight_decay": group_config.get("weight_decay", self.weight_decay),
+            }
+            optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
+
+        return optimizers
+
+
+def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
+    """Save optimizer state to disk.
+
+    Args:
+        optimizer: Either a single optimizer or a dictionary of optimizers.
+        save_dir: Directory to save the optimizer state.
+    """
+    if isinstance(optimizer, dict):
+        # Handle dictionary of optimizers
+        for name, opt in optimizer.items():
+            optimizer_dir = save_dir / name
+            optimizer_dir.mkdir(exist_ok=True, parents=True)
+            _save_single_optimizer_state(opt, optimizer_dir)
+    else:
+        # Handle single optimizer
+        _save_single_optimizer_state(optimizer, save_dir)
+
+
+def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
+    """Save a single optimizer's state to disk."""
     state = optimizer.state_dict()
     param_groups = state.pop("param_groups")
     flat_state = flatten_dict(state)
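Not part of the diff: a minimal sketch of how MultiAdamConfig.build is meant to be consumed with a mapping from group names to parameter lists. The actor/critic modules and their shapes are illustrative only (they mirror the parameter-group names used in the tests further down).

import torch

from lerobot.common.optim.optimizers import MultiAdamConfig

# Illustrative modules; any mapping of group name -> list of parameters works.
actor = torch.nn.Linear(10, 10)
critic = torch.nn.Linear(5, 5)

cfg = MultiAdamConfig(
    lr=1e-3,
    optimizer_groups={
        "actor": {"lr": 1e-4},
        "critic": {"lr": 5e-4, "weight_decay": 1e-6},
    },
)
optimizers = cfg.build(
    {
        "actor": list(actor.parameters()),
        "critic": list(critic.parameters()),
    }
)
assert optimizers["actor"].defaults["lr"] == 1e-4  # group-specific override
assert optimizers["critic"].defaults["weight_decay"] == 1e-6
assert optimizers["critic"].defaults["betas"] == (0.9, 0.999)  # falls back to the default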
@@ -102,11 +169,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
     write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
 
 
-def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
+def load_optimizer_state(
+    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
+) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
+    """Load optimizer state from disk.
+
+    Args:
+        optimizer: Either a single optimizer or a dictionary of optimizers.
+        save_dir: Directory to load the optimizer state from.
+
+    Returns:
+        The updated optimizer(s) with loaded state.
+    """
+    if isinstance(optimizer, dict):
+        # Handle dictionary of optimizers
+        loaded_optimizers = {}
+        for name, opt in optimizer.items():
+            optimizer_dir = save_dir / name
+            if optimizer_dir.exists():
+                loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
+            else:
+                loaded_optimizers[name] = opt
+        return loaded_optimizers
+    else:
+        # Handle single optimizer
+        return _load_single_optimizer_state(optimizer, save_dir)
+
+
+def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
+    """Load a single optimizer's state from disk."""
     current_state_dict = optimizer.state_dict()
     flat_state = load_file(save_dir / OPTIMIZER_STATE)
     state = unflatten_dict(flat_state)
-    loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
+
+    # Handle case where 'state' key might not exist (for newly created optimizers)
+    if "state" in state:
+        loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
+    else:
+        loaded_state_dict = {"state": {}}
 
     if "param_groups" in current_state_dict:
         param_groups = deserialize_json_into_object(
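Not part of the diff: a minimal round-trip sketch of the dict-aware save_optimizer_state / load_optimizer_state pair. The "outputs/optim" path is a placeholder; each optimizer group is written to its own sub-directory and restored into a freshly built dict of optimizers, which is the behavior the empty-state test below relies on.

from pathlib import Path

import torch

from lerobot.common.optim.optimizers import (
    MultiAdamConfig,
    load_optimizer_state,
    save_optimizer_state,
)

params = {"actor": [torch.nn.Parameter(torch.randn(10, 10))]}
cfg = MultiAdamConfig(lr=1e-3, optimizer_groups={"actor": {"lr": 1e-4}})

optimizers = cfg.build(params)
save_optimizer_state(optimizers, Path("outputs/optim"))  # creates outputs/optim/actor/

# A fresh dict built from the same config picks the saved state back up;
# groups without a matching sub-directory are returned unchanged.
restored = load_optimizer_state(cfg.build(params), Path("outputs/optim"))
assert restored["actor"].defaults["lr"] == 1e-4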

Tests for lerobot.common.optim.optimizers

@@ -21,6 +21,7 @@ from lerobot.common.constants import (
 from lerobot.common.optim.optimizers import (
     AdamConfig,
     AdamWConfig,
+    MultiAdamConfig,
     SGDConfig,
     load_optimizer_state,
     save_optimizer_state,
@@ -33,13 +34,21 @@ from lerobot.common.optim.optimizers import (
         (AdamConfig, torch.optim.Adam),
         (AdamWConfig, torch.optim.AdamW),
         (SGDConfig, torch.optim.SGD),
+        (MultiAdamConfig, dict),
     ],
 )
 def test_optimizer_build(config_cls, expected_class, model_params):
     config = config_cls()
-    optimizer = config.build(model_params)
-    assert isinstance(optimizer, expected_class)
-    assert optimizer.defaults["lr"] == config.lr
+    if config_cls == MultiAdamConfig:
+        params_dict = {"default": model_params}
+        optimizer = config.build(params_dict)
+        assert isinstance(optimizer, expected_class)
+        assert isinstance(optimizer["default"], torch.optim.Adam)
+        assert optimizer["default"].defaults["lr"] == config.lr
+    else:
+        optimizer = config.build(model_params)
+        assert isinstance(optimizer, expected_class)
+        assert optimizer.defaults["lr"] == config.lr
 
 
 def test_save_optimizer_state(optimizer, tmp_path):
@@ -54,3 +63,185 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
     loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
 
     torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
+
+
+@pytest.fixture
+def base_params_dict():
+    return {
+        "actor": [torch.nn.Parameter(torch.randn(10, 10))],
+        "critic": [torch.nn.Parameter(torch.randn(5, 5))],
+        "temperature": [torch.nn.Parameter(torch.randn(3, 3))],
+    }
+
+
+@pytest.mark.parametrize(
+    "config_params, expected_values",
+    [
+        # Test 1: Basic configuration with different learning rates
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4},
+                    "critic": {"lr": 5e-4},
+                    "temperature": {"lr": 2e-3},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+            },
+        ),
+        # Test 2: Different weight decays and beta values
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4, "weight_decay": 1e-5},
+                    "critic": {"lr": 5e-4, "weight_decay": 1e-6},
+                    "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
+            },
+        ),
+        # Test 3: Epsilon parameter customization
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4, "eps": 1e-6},
+                    "critic": {"lr": 5e-4, "eps": 1e-7},
+                    "temperature": {"lr": 2e-3, "eps": 1e-8},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
+            },
+        ),
+    ],
+)
+def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
+    # Create config with the given parameters
+    config = MultiAdamConfig(**config_params)
+    optimizers = config.build(base_params_dict)
+
+    # Verify optimizer count and keys
+    assert len(optimizers) == len(expected_values)
+    assert set(optimizers.keys()) == set(expected_values.keys())
+
+    # Check that all optimizers are Adam instances
+    for opt in optimizers.values():
+        assert isinstance(opt, torch.optim.Adam)
+
+    # Verify hyperparameters for each optimizer
+    for name, expected in expected_values.items():
+        optimizer = optimizers[name]
+        for param, value in expected.items():
+            assert optimizer.defaults[param] == value
+
+
+@pytest.fixture
+def multi_optimizers(base_params_dict):
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    return config.build(base_params_dict)
+
+
+def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
+    # Save optimizer states
+    save_optimizer_state(multi_optimizers, tmp_path)
+
+    # Verify that directories were created for each optimizer
+    for name in multi_optimizers.keys():
+        assert (tmp_path / name).is_dir()
+        assert (tmp_path / name / OPTIMIZER_STATE).is_file()
+        assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
+
+
+def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
+    # Run a minimal backward pass and update so the optimizers have non-empty state to save
+    for name, params in base_params_dict.items():
+        if name in multi_optimizers:
+            # Create a dummy loss and do backward
+            dummy_loss = params[0].sum()
+            dummy_loss.backward()
+
+            # Perform an optimization step
+            multi_optimizers[name].step()
+
+            # Zero gradients for next steps
+            multi_optimizers[name].zero_grad()
+
+    # Save optimizer states
+    save_optimizer_state(multi_optimizers, tmp_path)
+
+    # Create new optimizers with the same config
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    new_optimizers = config.build(base_params_dict)
+
+    # Load optimizer states
+    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
+
+    # Verify state dictionaries match
+    for name in multi_optimizers.keys():
+        torch.testing.assert_close(
+            multi_optimizers[name].state_dict(),
+            loaded_optimizers[name].state_dict()
+        )
+
+
+def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
+    """Test saving and loading optimizer states even when the state is empty (no backward pass)."""
+    # Create config and build optimizers
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    optimizers = config.build(base_params_dict)
+
+    # Save optimizer states without any backward pass (empty state)
+    save_optimizer_state(optimizers, tmp_path)
+
+    # Create new optimizers with the same config
+    new_optimizers = config.build(base_params_dict)
+
+    # Load optimizer states
+    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
+
+    # Verify hyperparameters match even with empty state
+    for name, optimizer in optimizers.items():
+        assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
+        assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
+        assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
+
+        # Verify state dictionaries match (they will be empty)
+        torch.testing.assert_close(
+            optimizer.state_dict()["param_groups"],
+            loaded_optimizers[name].state_dict()["param_groups"]
+        )
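Not part of the diff: the grad_clip_norm field is stored on MultiAdamConfig but this commit does not show where it is consumed, so the following per-group update step is only an assumed usage pattern; the step_all helper is hypothetical.

import torch


def step_all(
    optimizers: dict[str, torch.optim.Optimizer],
    params_dict: dict[str, list[torch.nn.Parameter]],
    grad_clip_norm: float,
) -> None:
    """Hypothetical helper: clip and step each optimizer group independently."""
    for name, optimizer in optimizers.items():
        # Clip this group's gradients before stepping, matching the docstring of grad_clip_norm.
        torch.nn.utils.clip_grad_norm_(params_dict[name], grad_clip_norm)
        optimizer.step()
        optimizer.zero_grad()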