#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from dataclasses import asdict, dataclass
from pathlib import Path

import draccus
import torch
from safetensors.torch import load_file, save_file

from lerobot.common.constants import (
    OPTIMIZER_PARAM_GROUPS,
    OPTIMIZER_STATE,
)
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object


@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base config for optimizers, registered as a draccus choice ("adam", "adamw", "sgd")."""

    lr: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        # Name under which the concrete subclass was registered, e.g. "adam".
        return self.get_choice_name(self.__class__)

    @classmethod
    def default_choice_name(cls) -> str | None:
        return "adam"

    @abc.abstractmethod
    def build(self) -> torch.optim.Optimizer:
        raise NotImplementedError
@OptimizerConfig.register_subclass("adam")
|
|
@dataclass
|
|
class AdamConfig(OptimizerConfig):
|
|
lr: float = 1e-3
|
|
betas: tuple[float, float] = (0.9, 0.999)
|
|
eps: float = 1e-8
|
|
weight_decay: float = 0.0
|
|
grad_clip_norm: float = 10.0
|
|
|
|
def build(self, params: dict) -> torch.optim.Optimizer:
|
|
kwargs = asdict(self)
|
|
kwargs.pop("grad_clip_norm")
|
|
return torch.optim.Adam(params, **kwargs)
|
|
|
|
|
|
@OptimizerConfig.register_subclass("adamw")
|
|
@dataclass
|
|
class AdamWConfig(OptimizerConfig):
|
|
lr: float = 1e-3
|
|
betas: tuple[float, float] = (0.9, 0.999)
|
|
eps: float = 1e-8
|
|
weight_decay: float = 1e-2
|
|
grad_clip_norm: float = 10.0
|
|
|
|
def build(self, params: dict) -> torch.optim.Optimizer:
|
|
kwargs = asdict(self)
|
|
kwargs.pop("grad_clip_norm")
|
|
return torch.optim.AdamW(params, **kwargs)
|
|
|
|
|
|
@OptimizerConfig.register_subclass("sgd")
|
|
@dataclass
|
|
class SGDConfig(OptimizerConfig):
|
|
lr: float = 1e-3
|
|
momentum: float = 0.0
|
|
dampening: float = 0.0
|
|
nesterov: bool = False
|
|
weight_decay: float = 0.0
|
|
grad_clip_norm: float = 10.0
|
|
|
|
def build(self, params: dict) -> torch.optim.Optimizer:
|
|
kwargs = asdict(self)
|
|
kwargs.pop("grad_clip_norm")
|
|
return torch.optim.SGD(params, **kwargs)
|
|
|
|
|
|
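

# Usage sketch (illustrative only): a concrete config builds the corresponding torch
# optimizer, while `grad_clip_norm` is left for the training loop to apply, e.g.:
#
#   cfg = AdamWConfig(lr=1e-4)
#   optimizer = cfg.build(policy.parameters())
#   loss.backward()
#   torch.nn.utils.clip_grad_norm_(policy.parameters(), cfg.grad_clip_norm)
#   optimizer.step()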


def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
    state = optimizer.state_dict()
    param_groups = state.pop("param_groups")
    # Tensor state (e.g. Adam moments) goes into a safetensors file; param groups are written as JSON.
    flat_state = flatten_dict(state)
    save_file(flat_state, save_dir / OPTIMIZER_STATE)
    write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)


def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    current_state_dict = optimizer.state_dict()
    flat_state = load_file(save_dir / OPTIMIZER_STATE)
    state = unflatten_dict(flat_state)
    # safetensors keys are strings, so convert the parameter indices back to ints.
    loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}

    if "param_groups" in current_state_dict:
        param_groups = deserialize_json_into_object(
            save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
        )
        loaded_state_dict["param_groups"] = param_groups

    optimizer.load_state_dict(loaded_state_dict)
    return optimizer
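

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative addition): a toy `torch.nn.Linear`
# stands in for a policy, and a temporary directory serves as the checkpoint
# folder for the save/load round trip.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    from torch import nn

    policy = nn.Linear(4, 2)
    cfg = AdamConfig(lr=1e-4)
    optimizer = cfg.build(policy.parameters())

    # Take one training step so the optimizer has non-empty state to serialize.
    loss = policy(torch.randn(8, 4)).sum()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), cfg.grad_clip_norm)
    optimizer.step()
    optimizer.zero_grad()

    with tempfile.TemporaryDirectory() as tmp:
        save_dir = Path(tmp)
        save_optimizer_state(optimizer, save_dir)

        # Restore into a freshly built optimizer over the same parameters.
        restored = load_optimizer_state(cfg.build(policy.parameters()), save_dir)
        print(restored.state_dict()["param_groups"])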