diff --git a/Makefile b/Makefile
index 9df4fe60..9a8a2474 100644
--- a/Makefile
+++ b/Makefile
@@ -20,6 +20,8 @@ build-gpu:
 test-end-to-end:
 	${MAKE} test-act-ete-train
 	${MAKE} test-act-ete-eval
+	${MAKE} test-act-ete-train-amp
+	${MAKE} test-act-ete-eval-amp
 	${MAKE} test-diffusion-ete-train
 	${MAKE} test-diffusion-ete-eval
 	${MAKE} test-tdmpc-ete-train
@@ -29,6 +31,7 @@ test-end-to-end:
 test-act-ete-train:
 	python lerobot/scripts/train.py \
 		policy=act \
+		policy.dim_model=64 \
 		env=aloha \
 		wandb.enable=False \
 		training.offline_steps=2 \
@@ -51,9 +54,40 @@ test-act-ete-eval:
 		env.episode_length=8 \
 		device=cpu \

+test-act-ete-train-amp:
+	python lerobot/scripts/train.py \
+		policy=act \
+		policy.dim_model=64 \
+		env=aloha \
+		wandb.enable=False \
+		training.offline_steps=2 \
+		training.online_steps=0 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		device=cpu \
+		training.save_model=true \
+		training.save_freq=2 \
+		policy.n_action_steps=20 \
+		policy.chunk_size=20 \
+		training.batch_size=2 \
+		hydra.run.dir=tests/outputs/act/ \
+		use_amp=true
+
+test-act-ete-eval-amp:
+	python lerobot/scripts/eval.py \
+		-p tests/outputs/act/checkpoints/000002 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=8 \
+		device=cpu \
+		use_amp=true
+
 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
 		policy=diffusion \
+		policy.down_dims=\[64,128,256\] \
+		policy.diffusion_step_embed_dim=32 \
+		policy.num_inference_steps=10 \
 		env=pusht \
 		wandb.enable=False \
 		training.offline_steps=2 \
@@ -101,7 +135,6 @@ test-tdmpc-ete-eval:
 		env.episode_length=8 \
 		device=cpu \

-
 test-default-ete-eval:
 	python lerobot/scripts/eval.py \
 		--config lerobot/configs/default.yaml \
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index e35deba1..27329bc9 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -10,6 +10,9 @@ hydra:
     name: default

 device: cuda  # cpu
+# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+# automatic gradient scaling is used.
+use_amp: false
 # `seed` is used for training (eg: model initialization, dataset shuffling)
 # AND for the evaluation environments.
 seed: ???
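
[Illustration, not part of the patch] The new `use_amp` option is consumed by both scripts through an optional `torch.autocast` context, as the eval.py and train.py hunks below show. A minimal, self-contained Python sketch of that selection logic follows; the `inference_ctx` helper and the toy model are hypothetical, not LeRobot API.

# Illustrative sketch only (assumes nothing from LeRobot beyond torch itself).
# Mirrors how eval.py/train.py choose the inference context based on `use_amp`.
from contextlib import nullcontext

import torch


def inference_ctx(device: torch.device, use_amp: bool):
    # Hypothetical helper: autocast when AMP is requested, otherwise a no-op context.
    return torch.autocast(device_type=device.type) if use_amp else nullcontext()


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(8, 2).to(device)
    x = torch.randn(4, 8, device=device)
    with torch.no_grad(), inference_ctx(device, use_amp=True):
        y = model(x)
    # Under autocast, matmul-heavy ops run in reduced precision (float16 on CUDA, bfloat16 on CPU).
    print(y.dtype)
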
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 9c95633a..7e4690d0 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -46,6 +46,7 @@ import json
 import logging
 import threading
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
@@ -520,7 +521,7 @@ def eval(
         raise NotImplementedError()

     # Check device is available
-    get_safe_torch_device(hydra_cfg.device, log=True)
+    device = get_safe_torch_device(hydra_cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -539,16 +540,17 @@ def eval(
         policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
     policy.eval()

-    info = eval_policy(
-        env,
-        policy,
-        hydra_cfg.eval.n_episodes,
-        max_episodes_rendered=10,
-        video_dir=Path(out_dir) / "eval",
-        start_seed=hydra_cfg.seed,
-        enable_progbar=True,
-        enable_inner_progbar=True,
-    )
+    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
+        info = eval_policy(
+            env,
+            policy,
+            hydra_cfg.eval.n_episodes,
+            max_episodes_rendered=10,
+            video_dir=Path(out_dir) / "eval",
+            start_seed=hydra_cfg.seed,
+            enable_progbar=True,
+            enable_inner_progbar=True,
+        )
     print(info["aggregated"])

     # Save info
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 58a2bd01..2b28943d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -15,12 +15,14 @@
 # limitations under the License.
 import logging
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path

 import hydra
 import torch
 from omegaconf import DictConfig
+from torch.cuda.amp import GradScaler

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -28,6 +30,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import PolicyWithUpdate
+from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
@@ -83,21 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy):
     return optimizer, lr_scheduler


-def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+def update_policy(
+    policy,
+    batch,
+    optimizer,
+    grad_clip_norm,
+    grad_scaler: GradScaler,
+    lr_scheduler=None,
+    use_amp: bool = False,
+):
     """Returns a dictionary of items for logging."""
-    start_time = time.time()
+    start_time = time.perf_counter()
+    device = get_device_from_parameters(policy)
     policy.train()
-    output_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = output_dict["loss"]
-    loss.backward()
+    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+        output_dict = policy.forward(batch)
+        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+        loss = output_dict["loss"]
+        grad_scaler.scale(loss).backward()
+
+    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+    grad_scaler.unscale_(optimizer)
+
     grad_norm = torch.nn.utils.clip_grad_norm_(
         policy.parameters(),
         grad_clip_norm,
         error_if_nonfinite=False,
     )

-    optimizer.step()
+    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+    grad_scaler.step(optimizer)
+    # Updates the scale for next iteration.
+    grad_scaler.update()
+
     optimizer.zero_grad()

     if lr_scheduler is not None:
@@ -111,7 +133,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
-        "update_s": time.time() - start_time,
+        "update_s": time.perf_counter() - start_time,
         **{k: v for k, v in output_dict.items() if k != "loss"},
     }

@@ -219,7 +241,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         raise NotImplementedError("Online training is not implemented yet.")

     # Check device is available
-    get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -237,6 +259,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    grad_scaler = GradScaler(enabled=cfg.use_amp)

     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
@@ -257,14 +280,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     def evaluate_and_checkpoint_if_needed(step):
         if step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
-            eval_info = eval_policy(
-                eval_env,
-                policy,
-                cfg.eval.n_episodes,
-                video_dir=Path(out_dir) / "eval",
-                max_episodes_rendered=4,
-                start_seed=cfg.seed,
-            )
+            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+                eval_info = eval_policy(
+                    eval_env,
+                    policy,
+                    cfg.eval.n_episodes,
+                    video_dir=Path(out_dir) / "eval",
+                    max_episodes_rendered=4,
+                    start_seed=cfg.seed,
+                )
             log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(eval_info["video_paths"][0], step, mode="eval")
@@ -288,7 +312,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         num_workers=4,
         batch_size=cfg.training.batch_size,
         shuffle=True,
-        pin_memory=cfg.device != "cpu",
+        pin_memory=device.type != "cpu",
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
@@ -301,9 +325,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         batch = next(dl_iter)

         for key in batch:
-            batch[key] = batch[key].to(cfg.device, non_blocking=True)
+            batch[key] = batch[key].to(device, non_blocking=True)

-        train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
+        train_info = update_policy(
+            policy,
+            batch,
+            optimizer,
+            cfg.training.grad_clip_norm,
+            grad_scaler=grad_scaler,
+            lr_scheduler=lr_scheduler,
+            use_amp=cfg.use_amp,
+        )

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.training.log_freq == 0:
@@ -329,7 +361,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             num_workers=4,
             batch_size=cfg.training.batch_size,
             sampler=sampler,
-            pin_memory=cfg.device != "cpu",
+            pin_memory=device.type != "cpu",
             drop_last=False,
         )
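
[Illustration, not part of the patch] The self-contained sketch below reproduces the gradient-scaling order that `update_policy()` now follows: scale the loss, backward, `unscale_` before clipping, then `step`/`update`. The toy model, optimizer, and data are placeholders, not LeRobot code.

# Illustrative sketch only: same scale -> backward -> unscale_ -> clip -> step -> update
# ordering as update_policy(), applied to a placeholder model.
from contextlib import nullcontext

import torch
from torch.cuda.amp import GradScaler

use_amp = torch.cuda.is_available()
device = torch.device("cuda" if use_amp else "cpu")

model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
grad_scaler = GradScaler(enabled=use_amp)

x = torch.randn(16, 8, device=device)
target = torch.randn(16, 1, device=device)

with torch.autocast(device_type=device.type) if use_amp else nullcontext():
    loss = torch.nn.functional.mse_loss(model(x), target)
    # Scaling the loss keeps small float16 gradients from underflowing to zero.
    grad_scaler.scale(loss).backward()

# Unscale in-place first so the clip threshold applies to the true gradient magnitudes.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=False)

# step() skips optimizer.step() if gradients contain infs/NaNs; update() adapts the scale factor.
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad()

print(f"loss={loss.item():.4f} grad_norm={float(grad_norm):.4f}")

Unscaling before `clip_grad_norm_` matters because otherwise the clip threshold would be compared against scaled gradients, making the effective `grad_clip_norm` depend on the current scale factor.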