Handle multi optimizers
parent b4ec6c8afb
commit 425e604f76

@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from pathlib import Path
+from typing import Any

 import draccus
 import torch
@@ -44,7 +45,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
         return "adam"

     @abc.abstractmethod
-    def build(self) -> torch.optim.Optimizer:
+    def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
         raise NotImplementedError


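
Since build() may now return either a single optimizer or a dictionary of named optimizers, call sites have to branch on the result. A minimal sketch of that pattern, not taken from this diff (build_result is a hypothetical stand-in for whatever a config's build() returned):

import torch

# Sketch only: normalize build()'s result so the rest of the code can treat the
# single-optimizer and multi-optimizer cases uniformly.
build_result = torch.optim.Adam([torch.nn.Parameter(torch.randn(2, 2))], lr=1e-3)  # placeholder
optimizers = build_result if isinstance(build_result, dict) else {"default": build_result}
for opt in optimizers.values():
    opt.zero_grad()
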
@@ -94,7 +95,73 @@ class SGDConfig(OptimizerConfig):
         return torch.optim.SGD(params, **kwargs)


-def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
+@OptimizerConfig.register_subclass("multi_adam")
+@dataclass
+class MultiAdamConfig(OptimizerConfig):
+    """Configuration for multiple Adam optimizers with different parameter groups.
+
+    This creates a dictionary of Adam optimizers, each with its own hyperparameters.
+
+    Args:
+        lr: Default learning rate (used if not specified for a group)
+        weight_decay: Default weight decay (used if not specified for a group)
+        optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
+        grad_clip_norm: Gradient clipping norm
+    """
+    lr: float = 1e-3
+    weight_decay: float = 0.0
+    grad_clip_norm: float = 10.0
+    optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
+
+    def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
+        """Build multiple Adam optimizers.
+
+        Args:
+            params_dict: Dictionary mapping parameter group names to lists of parameters
+                The keys should match the keys in optimizer_groups
+
+        Returns:
+            Dictionary mapping parameter group names to their optimizers
+        """
+        optimizers = {}
+
+        for name, params in params_dict.items():
+            # Get group-specific hyperparameters or use defaults
+            group_config = self.optimizer_groups.get(name, {})
+
+            # Create optimizer with merged parameters (defaults + group-specific)
+            optimizer_kwargs = {
+                "lr": group_config.get("lr", self.lr),
+                "betas": group_config.get("betas", (0.9, 0.999)),
+                "eps": group_config.get("eps", 1e-5),
+                "weight_decay": group_config.get("weight_decay", self.weight_decay),
+            }
+
+            optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
+
+        return optimizers
+
+
+def save_optimizer_state(optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path) -> None:
+    """Save optimizer state to disk.
+
+    Args:
+        optimizer: Either a single optimizer or a dictionary of optimizers.
+        save_dir: Directory to save the optimizer state.
+    """
+    if isinstance(optimizer, dict):
+        # Handle dictionary of optimizers
+        for name, opt in optimizer.items():
+            optimizer_dir = save_dir / name
+            optimizer_dir.mkdir(exist_ok=True, parents=True)
+            _save_single_optimizer_state(opt, optimizer_dir)
+    else:
+        # Handle single optimizer
+        _save_single_optimizer_state(optimizer, save_dir)
+
+
+def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
+    """Save a single optimizer's state to disk."""
     state = optimizer.state_dict()
     param_groups = state.pop("param_groups")
     flat_state = flatten_dict(state)
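
As a usage illustration for the new MultiAdamConfig (a sketch, not part of the commit; the group names and tensor shapes are arbitrary), per-group overrides are merged with the config-level defaults like this:

import torch
from lerobot.common.optim.optimizers import MultiAdamConfig

config = MultiAdamConfig(
    lr=1e-3,
    weight_decay=1e-4,
    optimizer_groups={"actor": {"lr": 1e-4}, "critic": {"betas": (0.95, 0.999)}},
)
params_dict = {
    "actor": [torch.nn.Parameter(torch.randn(4, 4))],
    "critic": [torch.nn.Parameter(torch.randn(4, 4))],
}
optimizers = config.build(params_dict)

# "actor" takes its override lr=1e-4; "critic" keeps the default lr=1e-3 but uses
# its overridden betas; anything unspecified falls back to the config defaults.
assert optimizers["actor"].defaults["lr"] == 1e-4
assert optimizers["critic"].defaults["lr"] == 1e-3
assert optimizers["critic"].defaults["betas"] == (0.95, 0.999)
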
@@ -102,11 +169,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
     write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)


-def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
+def load_optimizer_state(
+    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
+) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
+    """Load optimizer state from disk.
+
+    Args:
+        optimizer: Either a single optimizer or a dictionary of optimizers.
+        save_dir: Directory to load the optimizer state from.
+
+    Returns:
+        The updated optimizer(s) with loaded state.
+    """
+    if isinstance(optimizer, dict):
+        # Handle dictionary of optimizers
+        loaded_optimizers = {}
+        for name, opt in optimizer.items():
+            optimizer_dir = save_dir / name
+            if optimizer_dir.exists():
+                loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
+            else:
+                loaded_optimizers[name] = opt
+        return loaded_optimizers
+    else:
+        # Handle single optimizer
+        return _load_single_optimizer_state(optimizer, save_dir)
+
+
+def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
+    """Load a single optimizer's state from disk."""
     current_state_dict = optimizer.state_dict()
     flat_state = load_file(save_dir / OPTIMIZER_STATE)
     state = unflatten_dict(flat_state)
-    loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
+
+    # Handle case where 'state' key might not exist (for newly created optimizers)
+    if "state" in state:
+        loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
+    else:
+        loaded_state_dict = {"state": {}}

     if "param_groups" in current_state_dict:
         param_groups = deserialize_json_into_object(
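
To round out the save/load changes above, a small sketch of the round trip (not part of the commit; the checkpoint path is a placeholder): with a dictionary of optimizers, each entry is written under its own subdirectory of save_dir and restored by name, while a single optimizer is still written directly into save_dir as before.

from pathlib import Path

import torch
from lerobot.common.optim.optimizers import (
    MultiAdamConfig,
    load_optimizer_state,
    save_optimizer_state,
)

params = {"actor": [torch.nn.Parameter(torch.randn(4, 4))]}
config = MultiAdamConfig(lr=1e-3, optimizer_groups={"actor": {"lr": 1e-4}})
optimizers = config.build(params)

save_dir = Path("checkpoints/optimizer")  # placeholder location
save_optimizer_state(optimizers, save_dir)  # writes checkpoints/optimizer/actor/...
restored = load_optimizer_state(config.build(params), save_dir)
assert restored["actor"].defaults["lr"] == 1e-4
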
@@ -21,6 +21,7 @@ from lerobot.common.constants import (
 from lerobot.common.optim.optimizers import (
     AdamConfig,
     AdamWConfig,
+    MultiAdamConfig,
     SGDConfig,
     load_optimizer_state,
     save_optimizer_state,
@@ -33,13 +34,21 @@ from lerobot.common.optim.optimizers import (
         (AdamConfig, torch.optim.Adam),
         (AdamWConfig, torch.optim.AdamW),
         (SGDConfig, torch.optim.SGD),
+        (MultiAdamConfig, dict),
     ],
 )
 def test_optimizer_build(config_cls, expected_class, model_params):
     config = config_cls()
-    optimizer = config.build(model_params)
-    assert isinstance(optimizer, expected_class)
-    assert optimizer.defaults["lr"] == config.lr
+    if config_cls == MultiAdamConfig:
+        params_dict = {"default": model_params}
+        optimizer = config.build(params_dict)
+        assert isinstance(optimizer, expected_class)
+        assert isinstance(optimizer["default"], torch.optim.Adam)
+        assert optimizer["default"].defaults["lr"] == config.lr
+    else:
+        optimizer = config.build(model_params)
+        assert isinstance(optimizer, expected_class)
+        assert optimizer.defaults["lr"] == config.lr


 def test_save_optimizer_state(optimizer, tmp_path):
@@ -54,3 +63,185 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
     loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)

     torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
+
+
+@pytest.fixture
+def base_params_dict():
+    return {
+        "actor": [torch.nn.Parameter(torch.randn(10, 10))],
+        "critic": [torch.nn.Parameter(torch.randn(5, 5))],
+        "temperature": [torch.nn.Parameter(torch.randn(3, 3))],
+    }
+
+
+@pytest.mark.parametrize(
+    "config_params, expected_values",
+    [
+        # Test 1: Basic configuration with different learning rates
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4},
+                    "critic": {"lr": 5e-4},
+                    "temperature": {"lr": 2e-3},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
+            },
+        ),
+        # Test 2: Different weight decays and beta values
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4, "weight_decay": 1e-5},
+                    "critic": {"lr": 5e-4, "weight_decay": 1e-6},
+                    "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
+            },
+        ),
+        # Test 3: Epsilon parameter customization
+        (
+            {
+                "lr": 1e-3,
+                "weight_decay": 1e-4,
+                "optimizer_groups": {
+                    "actor": {"lr": 1e-4, "eps": 1e-6},
+                    "critic": {"lr": 5e-4, "eps": 1e-7},
+                    "temperature": {"lr": 2e-3, "eps": 1e-8},
+                },
+            },
+            {
+                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
+                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
+                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
+            },
+        ),
+    ],
+)
+def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
+    # Create config with the given parameters
+    config = MultiAdamConfig(**config_params)
+    optimizers = config.build(base_params_dict)
+
+    # Verify optimizer count and keys
+    assert len(optimizers) == len(expected_values)
+    assert set(optimizers.keys()) == set(expected_values.keys())
+
+    # Check that all optimizers are Adam instances
+    for opt in optimizers.values():
+        assert isinstance(opt, torch.optim.Adam)
+
+    # Verify hyperparameters for each optimizer
+    for name, expected in expected_values.items():
+        optimizer = optimizers[name]
+        for param, value in expected.items():
+            assert optimizer.defaults[param] == value
+
+
+@pytest.fixture
+def multi_optimizers(base_params_dict):
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    return config.build(base_params_dict)
+
+
+def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
+    # Save optimizer states
+    save_optimizer_state(multi_optimizers, tmp_path)
+
+    # Verify that directories were created for each optimizer
+    for name in multi_optimizers.keys():
+        assert (tmp_path / name).is_dir()
+        assert (tmp_path / name / OPTIMIZER_STATE).is_file()
+        assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
+
+
+def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
+    # Run a minimal backward pass and step so the optimizer states are populated
+    for name, params in base_params_dict.items():
+        if name in multi_optimizers:
+            # Create a dummy loss and do backward
+            dummy_loss = params[0].sum()
+            dummy_loss.backward()
+            # Perform an optimization step
+            multi_optimizers[name].step()
+            # Zero gradients for next steps
+            multi_optimizers[name].zero_grad()
+
+    # Save optimizer states
+    save_optimizer_state(multi_optimizers, tmp_path)
+
+    # Create new optimizers with the same config
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    new_optimizers = config.build(base_params_dict)
+
+    # Load optimizer states
+    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
+
+    # Verify state dictionaries match
+    for name in multi_optimizers.keys():
+        torch.testing.assert_close(
+            multi_optimizers[name].state_dict(),
+            loaded_optimizers[name].state_dict(),
+        )
+
+
+def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
+    """Test saving and loading optimizer states even when the state is empty (no backward pass)."""
+    # Create config and build optimizers
+    config = MultiAdamConfig(
+        lr=1e-3,
+        optimizer_groups={
+            "actor": {"lr": 1e-4},
+            "critic": {"lr": 5e-4},
+            "temperature": {"lr": 2e-3},
+        },
+    )
+    optimizers = config.build(base_params_dict)
+
+    # Save optimizer states without any backward pass (empty state)
+    save_optimizer_state(optimizers, tmp_path)
+
+    # Create new optimizers with the same config
+    new_optimizers = config.build(base_params_dict)
+
+    # Load optimizer states
+    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
+
+    # Verify hyperparameters match even with empty state
+    for name, optimizer in optimizers.items():
+        assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
+        assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
+        assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
+
+        # Verify state dictionaries match (they will be empty)
+        torch.testing.assert_close(
+            optimizer.state_dict()["param_groups"],
+            loaded_optimizers[name].state_dict()["param_groups"],
+        )