#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any

import draccus
import torch
from safetensors.torch import load_file, save_file

from lerobot.common.constants import (
    OPTIMIZER_PARAM_GROUPS,
    OPTIMIZER_STATE,
)
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object

@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    lr: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)

    @classmethod
    def default_choice_name(cls) -> str | None:
        return "adam"

    @abc.abstractmethod
    def build(self, params: dict) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
        raise NotImplementedError
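
# Added commentary (not part of the original file): each concrete config below
# registers itself in the draccus choice registry, so the optimizer can be picked
# by name from a training config, with "adam" as the default choice. A minimal
# sketch of the intended use (`model` is a placeholder nn.Module):
#
#   cfg = AdamConfig(lr=1e-4)
#   cfg.type                                    # -> "adam", the registered name
#   optimizer = cfg.build(model.parameters())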

@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.Adam(params, **kwargs)
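
# Illustrative usage (added commentary, not in the original file). `grad_clip_norm`
# is deliberately popped before constructing the optimizer: torch.optim.Adam does
# not accept it, and a training loop would typically apply it separately, e.g.:
#
#   cfg = AdamConfig(lr=1e-4)
#   optimizer = cfg.build(policy.parameters())  # `policy` is a placeholder nn.Module
#   loss.backward()
#   torch.nn.utils.clip_grad_norm_(policy.parameters(), cfg.grad_clip_norm)
#   optimizer.step()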

@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.AdamW(params, **kwargs)

@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    nesterov: bool = False
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.SGD(params, **kwargs)

@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
    """Configuration for multiple Adam optimizers with different parameter groups.

    This creates a dictionary of Adam optimizers, each with its own hyperparameters.

    Args:
        lr: Default learning rate (used if not specified for a group)
        weight_decay: Default weight decay (used if not specified for a group)
        optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
        grad_clip_norm: Gradient clipping norm
    """

    lr: float = 1e-3
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0
    optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)

    def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
        """Build multiple Adam optimizers.

        Args:
            params_dict: Dictionary mapping parameter group names to lists of parameters.
                The keys should match the keys in optimizer_groups.

        Returns:
            Dictionary mapping parameter group names to their optimizers.
        """
        optimizers = {}

        for name, params in params_dict.items():
            # Get group-specific hyperparameters or use defaults
            group_config = self.optimizer_groups.get(name, {})

            # Create optimizer with merged parameters (defaults + group-specific)
            optimizer_kwargs = {
                "lr": group_config.get("lr", self.lr),
                "betas": group_config.get("betas", (0.9, 0.999)),
                "eps": group_config.get("eps", 1e-5),
                "weight_decay": group_config.get("weight_decay", self.weight_decay),
            }

            optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)

        return optimizers
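
# Illustrative usage (added commentary, not in the original file): the keys of
# `optimizer_groups` are meant to line up with the keys of the `params_dict`
# passed to `build`. Note that the per-group fallback for `eps` above is 1e-5,
# rather than the 1e-8 default used by AdamConfig. Group names and parameter
# lists below are hypothetical:
#
#   cfg = MultiAdamConfig(
#       lr=1e-3,
#       optimizer_groups={"actor": {"lr": 3e-4}, "critic": {"weight_decay": 1e-2}},
#   )
#   optimizers = cfg.build({"actor": actor_params, "critic": critic_params})
#   optimizers["actor"].step()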

def save_optimizer_state(
    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> None:
    """Save optimizer state to disk.

    Args:
        optimizer: Either a single optimizer or a dictionary of optimizers.
        save_dir: Directory to save the optimizer state.
    """
    if isinstance(optimizer, dict):
        # Handle dictionary of optimizers
        for name, opt in optimizer.items():
            optimizer_dir = save_dir / name
            optimizer_dir.mkdir(exist_ok=True, parents=True)
            _save_single_optimizer_state(opt, optimizer_dir)
    else:
        # Handle single optimizer
        _save_single_optimizer_state(optimizer, save_dir)
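
# Added note (not in the original file): when given a dictionary of optimizers,
# each entry is written to its own subdirectory of `save_dir`, so a call such as
# save_optimizer_state({"actor": opt_a, "critic": opt_c}, save_dir) produces
# roughly the following layout (names hypothetical; the exact filenames come from
# the OPTIMIZER_STATE and OPTIMIZER_PARAM_GROUPS constants):
#
#   save_dir/actor/<OPTIMIZER_STATE>          # safetensors file with tensor state
#   save_dir/actor/<OPTIMIZER_PARAM_GROUPS>   # JSON file with parameter groups
#   save_dir/critic/...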

def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
    """Save a single optimizer's state to disk."""
    state = optimizer.state_dict()
    param_groups = state.pop("param_groups")
    flat_state = flatten_dict(state)
    save_file(flat_state, save_dir / OPTIMIZER_STATE)
    write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)

def load_optimizer_state(
    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
    """Load optimizer state from disk.

    Args:
        optimizer: Either a single optimizer or a dictionary of optimizers.
        save_dir: Directory to load the optimizer state from.

    Returns:
        The updated optimizer(s) with loaded state.
    """
    if isinstance(optimizer, dict):
        # Handle dictionary of optimizers
        loaded_optimizers = {}
        for name, opt in optimizer.items():
            optimizer_dir = save_dir / name
            if optimizer_dir.exists():
                loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
            else:
                loaded_optimizers[name] = opt
        return loaded_optimizers
    else:
        # Handle single optimizer
        return _load_single_optimizer_state(optimizer, save_dir)

def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    """Load a single optimizer's state from disk."""
    current_state_dict = optimizer.state_dict()
    flat_state = load_file(save_dir / OPTIMIZER_STATE)
    state = unflatten_dict(flat_state)

    # Handle case where 'state' key might not exist (for newly created optimizers)
    if "state" in state:
        loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
    else:
        loaded_state_dict = {"state": {}}

    if "param_groups" in current_state_dict:
        param_groups = deserialize_json_into_object(
            save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
        )
        loaded_state_dict["param_groups"] = param_groups

    optimizer.load_state_dict(loaded_state_dict)
    return optimizer
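
# Illustrative round trip (added commentary, not in the original file): state saved
# with `save_optimizer_state` can be restored into a freshly built optimizer of the
# same configuration (`policy` and `checkpoint_dir` are placeholders):
#
#   cfg = AdamConfig(lr=1e-4)
#   optimizer = cfg.build(policy.parameters())
#   save_optimizer_state(optimizer, checkpoint_dir)
#   ...
#   new_optimizer = cfg.build(policy.parameters())
#   new_optimizer = load_optimizer_state(new_optimizer, checkpoint_dir)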