supporting accelerate

commit 86e8c1d997 (parent 8699a28be0)
Author: mshukor
Date: 2025-02-25 21:38:17 +01:00

4 changed files with 101 additions and 41 deletions

View File

@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Callable

 from lerobot.common.utils.utils import format_big_number
@@ -93,12 +93,14 @@ class MetricsTracker:
         num_episodes: int,
         metrics: dict[str, AverageMeter],
         initial_step: int = 0,
+        accelerator: Callable = None,
     ):
         self.__dict__.update({k: None for k in self.__keys__})
         self._batch_size = batch_size
         self._num_frames = num_frames
         self._avg_samples_per_ep = num_frames / num_episodes
         self.metrics = metrics
+        self.accelerator = accelerator
         self.steps = initial_step

         # A sample is an (observation,action) pair, where observation and action
@@ -128,7 +130,7 @@ class MetricsTracker:
         Updates metrics that depend on 'step' for one step.
         """
         self.steps += 1
-        self.samples += self._batch_size
+        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
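To make the new samples arithmetic concrete, here is a minimal sketch; the stand-in accelerator class below is hypothetical and only carries the one attribute the tracker reads:

```python
# Hypothetical stand-in: the tracker only reads `num_processes` from the accelerator.
class FakeAccelerator:
    num_processes = 4  # e.g. a 4-GPU `accelerate launch`


batch_size = 8
accelerator = FakeAccelerator()

# Each tracked step now counts one batch per process instead of one batch total:
samples_per_step = batch_size * (accelerator.num_processes if accelerator else 1)
print(samples_per_step)  # 32 (it would stay 8 with accelerator=None)
```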

View File

@@ -16,7 +16,7 @@
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Generator
+from typing import Any, Generator, Callable

 import numpy as np
 import torch
@@ -163,14 +163,16 @@ def set_rng_state(random_state_dict: dict[str, Any]):
         torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])


-def set_seed(seed) -> None:
+def set_seed(seed, accelerator: Callable = None) -> None:
     """Set seed for reproducibility."""
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
+    if accelerator:
+        from accelerate.utils import set_seed
+        set_seed(seed)


 @contextmanager
 def seeded_context(seed: int) -> Generator[None, None, None]:
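A hedged usage sketch of the patched set_seed; the import path is assumed from the hunk context above, and accelerate must be installed:

```python
# Sketch only: when an Accelerator is passed, seeding is delegated to accelerate.utils.set_seed,
# which each process runs in addition to the random/numpy/torch seeding above.
import accelerate

from lerobot.common.utils.random_utils import set_seed  # assumed module path

accelerator = accelerate.Accelerator()
set_seed(1000, accelerator=accelerator)
```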

View File

@@ -20,10 +20,10 @@ import platform
 from copy import copy
 from datetime import datetime, timezone
 from pathlib import Path
+from typing import Callable
 import numpy as np
 import torch
-from typing import Any


 def none_or_int(value):
     if value == "None":
@@ -50,12 +50,12 @@ def auto_select_torch_device() -> torch.device:
         return torch.device("cpu")


-def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
+def get_safe_torch_device(try_device: str, log: bool = False, accelerator: Callable = None) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
     match try_device:
         case "cuda":
             assert torch.cuda.is_available()
-            device = torch.device("cuda")
+            device = accelerator.device if accelerator else torch.device("cuda")
         case "mps":
             assert torch.backends.mps.is_available()
             device = torch.device("mps")
@@ -103,7 +103,7 @@ def is_amp_available(device: str):
         raise ValueError(f"Unknown device '{device}.")


-def init_logging():
+def init_logging(accelerator: Callable = None):
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         fnameline = f"{record.pathname}:{record.lineno}"
@@ -120,7 +120,10 @@ def init_logging():
     console_handler = logging.StreamHandler()
     console_handler.setFormatter(formatter)
     logging.getLogger().addHandler(console_handler)
+    if accelerator is not None and not accelerator.is_main_process:
+        # Disable duplicate logging on non-main processes
+        logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
+        logging.getLogger().setLevel(logging.WARNING)


 def format_big_number(num, precision=0):
     suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -216,3 +219,18 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
     except TypeError:
         # If a TypeError is raised, the string is not a valid dtype
         return False
+
+
+def is_launched_with_accelerate() -> bool:
+    return "ACCELERATE_MIXED_PRECISION" in os.environ
+
+
+def get_accelerate_config(accelerator: Callable = None) -> dict[str, Any]:
+    config = {}
+    if not accelerator:
+        return config
+    config["num_processes"] = accelerator.num_processes
+    config["device"] = str(accelerator.device)
+    config["distributed_type"] = str(accelerator.distributed_type)
+    config["mixed_precision"] = accelerator.mixed_precision
+    config["gradient_accumulation_steps"] = accelerator.gradient_accumulation_steps
+    return config
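A hedged sketch of how these two helpers could be combined at startup to log the distributed configuration; the import path and the printed values are assumptions, not part of the diff:

```python
import logging
from pprint import pformat

import accelerate

from lerobot.common.utils.utils import get_accelerate_config, is_launched_with_accelerate  # assumed path

if is_launched_with_accelerate():  # `accelerate launch` exports ACCELERATE_MIXED_PRECISION
    accelerator = accelerate.Accelerator()
    logging.info(pformat(get_accelerate_config(accelerator)))
    # e.g. {'num_processes': 2, 'device': 'cuda:0', 'distributed_type': 'MULTI_GPU', ...}
```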

View File

@@ -17,7 +17,7 @@ import logging
 import time
 from contextlib import nullcontext
 from pprint import pformat
-from typing import Any
+from typing import Any, Callable

 import torch
 from termcolor import colored
@@ -46,6 +46,8 @@ from lerobot.common.utils.utils import (
     get_safe_torch_device,
     has_method,
     init_logging,
+    get_accelerate_config,
+    is_launched_with_accelerate
 )
 from lerobot.common.utils.wandb_utils import WandBLogger
 from lerobot.configs import parser
@@ -63,15 +65,26 @@ def update_policy(
     lr_scheduler=None,
     use_amp: bool = False,
     lock=None,
+    accelerator: Callable = None,
 ) -> tuple[MetricsTracker, dict]:
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
-    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+    with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)

-    grad_scaler.scale(loss).backward()
+    if accelerator:
+        accelerator.backward(loss)
+        accelerator.unscale_gradients(optimizer=optimizer)
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(),
+            grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+        optimizer.step()
+    else:
+        grad_scaler.scale(loss).backward()

         # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
         grad_scaler.unscale_(optimizer)
@@ -94,6 +107,10 @@ def update_policy(
     if lr_scheduler is not None:
         lr_scheduler.step()
+    if accelerator:
+        if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):  # FIXME(mshukor): avoid accelerator.unwrap_model ?
+            accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
+    else:
         if has_method(policy, "update"):
             # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
             policy.update()
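As a side note on the clipping sequence in this hunk, accelerate can also fold the unscale and clip into a single call; a sketch of that variant for comparison (this is not what the commit does):

```python
from typing import Iterable

import torch
from accelerate import Accelerator


def clip_and_step(
    accelerator: Accelerator,
    params: Iterable[torch.nn.Parameter],
    optimizer: torch.optim.Optimizer,
    grad_clip_norm: float,
):
    # Accelerator.clip_grad_norm_ unscales the gradients itself before clipping
    # and returns the total norm, so no separate unscale_gradients() call is needed.
    grad_norm = accelerator.clip_grad_norm_(params, grad_clip_norm)
    optimizer.step()
    return grad_norm
```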
@@ -106,10 +123,14 @@
 @parser.wrap()
-def train(cfg: TrainPipelineConfig):
+def train(cfg: TrainPipelineConfig, accelerator: Callable = None):
     cfg.validate()
     logging.info(pformat(cfg.to_dict()))

+    if accelerator and not accelerator.is_main_process:
+        # Disable logging on non-main processes.
+        cfg.wandb.enable = False
+
     if cfg.wandb.enable and cfg.wandb.project:
         wandb_logger = WandBLogger(cfg)
     else:
@@ -117,10 +138,10 @@ def train(cfg: TrainPipelineConfig):
         logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

     if cfg.seed is not None:
-        set_seed(cfg.seed)
+        set_seed(cfg.seed, accelerator=accelerator)

     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -141,7 +162,7 @@ def train(cfg: TrainPipelineConfig):
         device=device,
         ds_meta=dataset.meta,
     )
-    policy.to(device)

     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(device, enabled=cfg.use_amp)
@@ -184,6 +205,10 @@
         pin_memory=device.type != "cpu",
         drop_last=False,
     )
+    if accelerator:
+        policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, dataloader, lr_scheduler
+        )
     dl_iter = cycle(dataloader)

     policy.train()
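For reference, a self-contained sketch of what accelerator.prepare does to the objects passed here; the toy model and data below are placeholders, not LeRobot objects:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(16, 4))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
# The model is moved to accelerator.device (and wrapped for DDP when launched with
# several processes), and the dataloader now yields a distinct shard on each process.
```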
@@ -197,7 +222,7 @@
     }

     train_tracker = MetricsTracker(
-        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
+        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step, accelerator=accelerator
     )

     logging.info("Start offline training on a fixed dataset")
@@ -219,6 +244,7 @@
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
+            accelerator=accelerator,
         )

         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -238,21 +264,26 @@
                 wandb_logger.log_dict(wandb_log_dict, step)
             train_tracker.reset_averages()

-        if cfg.save_checkpoint and is_saving_step:
+        if cfg.save_checkpoint and is_saving_step and (not accelerator or accelerator.is_main_process):
             logging.info(f"Checkpoint policy after step {step}")
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+            save_checkpoint(checkpoint_dir, step, cfg, policy if not accelerator else accelerator.unwrap_model(policy), optimizer, lr_scheduler)
             update_last_checkpoint(checkpoint_dir)
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)
+        if accelerator:
+            accelerator.wait_for_everyone()

         if cfg.env and is_eval_step:
            step_id = get_step_identifier(step, cfg.steps)
            logging.info(f"Eval policy at step {step}")
-            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=device.type) if cfg.use_amp and not accelerator else nullcontext(),
+            ):
                 eval_info = eval_policy(
                     eval_env,
-                    policy,
+                    policy if not accelerator else accelerator.unwrap_model(policy),
                     cfg.eval.n_episodes,
                     videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                     max_episodes_rendered=4,
@@ -265,7 +296,7 @@
                 "eval_s": AverageMeter("eval_s", ":.3f"),
             }
             eval_tracker = MetricsTracker(
-                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
+                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step, accelerator=None
             )
             eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
             eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
@@ -283,4 +314,11 @@
 if __name__ == "__main__":
     init_logging()
+    if is_launched_with_accelerate():
+        import accelerate
+        # We set step_scheduler_with_optimizer False to prevent accelerate from
+        # adjusting the lr_scheduler steps based on the num_processes
+        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
+        train(accelerator=accelerator)
+    else:
         train()
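With this dispatch in place, the entry point is meant to be started either through accelerate, for example accelerate launch --num_processes=2 lerobot/scripts/train.py ..., which exports the ACCELERATE_* environment variables so the Accelerator branch is taken, or as a plain python lerobot/scripts/train.py ... run that falls through to the unchanged train() call. The script path and flags here are illustrative assumptions rather than part of this commit.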