Add Automatic Mixed Precision option for training and evaluation. (#199)

Alexander Soare 2024-05-20 18:57:54 +01:00 committed by GitHub
parent 2b270d085b
commit b6c216b590
4 changed files with 103 additions and 33 deletions

Makefile

@@ -20,6 +20,8 @@ build-gpu:
 test-end-to-end:
 	${MAKE} test-act-ete-train
 	${MAKE} test-act-ete-eval
+	${MAKE} test-act-ete-train-amp
+	${MAKE} test-act-ete-eval-amp
 	${MAKE} test-diffusion-ete-train
 	${MAKE} test-diffusion-ete-eval
 	${MAKE} test-tdmpc-ete-train
@@ -29,6 +31,7 @@ test-end-to-end:
 test-act-ete-train:
 	python lerobot/scripts/train.py \
 		policy=act \
+		policy.dim_model=64 \
 		env=aloha \
 		wandb.enable=False \
 		training.offline_steps=2 \
@@ -51,9 +54,40 @@ test-act-ete-eval:
 		env.episode_length=8 \
 		device=cpu \
 
+test-act-ete-train-amp:
+	python lerobot/scripts/train.py \
+		policy=act \
+		policy.dim_model=64 \
+		env=aloha \
+		wandb.enable=False \
+		training.offline_steps=2 \
+		training.online_steps=0 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		device=cpu \
+		training.save_model=true \
+		training.save_freq=2 \
+		policy.n_action_steps=20 \
+		policy.chunk_size=20 \
+		training.batch_size=2 \
+		hydra.run.dir=tests/outputs/act/ \
+		use_amp=true
+
+test-act-ete-eval-amp:
+	python lerobot/scripts/eval.py \
+		-p tests/outputs/act/checkpoints/000002 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=8 \
+		device=cpu \
+		use_amp=true
+
 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
 		policy=diffusion \
+		policy.down_dims=\[64,128,256\] \
+		policy.diffusion_step_embed_dim=32 \
+		policy.num_inference_steps=10 \
 		env=pusht \
 		wandb.enable=False \
 		training.offline_steps=2 \
@@ -101,7 +135,6 @@ test-tdmpc-ete-eval:
 		env.episode_length=8 \
 		device=cpu \
 
-
 test-default-ete-eval:
 	python lerobot/scripts/eval.py \
 		--config lerobot/configs/default.yaml \

lerobot/configs/default.yaml

@@ -10,6 +10,9 @@ hydra:
     name: default
 
 device: cuda # cpu
+# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+# automatic gradient scaling is used.
+use_amp: false
 # `seed` is used for training (eg: model initialization, dataset shuffling)
 # AND for the evaluation environments.
 seed: ???
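As the script changes below show, the flag selects between a `torch.autocast` context and a no-op `nullcontext()` (plus gradient scaling during training). A minimal sketch of that selection; the helper name `maybe_autocast` is illustrative and not part of the codebase:

```python
from contextlib import nullcontext

import torch


def maybe_autocast(device: torch.device, use_amp: bool):
    # use_amp=true: eligible ops run in reduced precision on the configured device
    # (e.g. float16 on CUDA). use_amp=false (the default): a do-nothing context,
    # so numerical behavior is unchanged.
    return torch.autocast(device_type=device.type) if use_amp else nullcontext()
```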

lerobot/scripts/eval.py

@@ -46,6 +46,7 @@ import json
 import logging
 import threading
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
@@ -520,7 +521,7 @@ def eval(
         raise NotImplementedError()
 
     # Check device is available
-    get_safe_torch_device(hydra_cfg.device, log=True)
+    device = get_safe_torch_device(hydra_cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -539,16 +540,17 @@ def eval(
     policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
     policy.eval()
 
-    info = eval_policy(
-        env,
-        policy,
-        hydra_cfg.eval.n_episodes,
-        max_episodes_rendered=10,
-        video_dir=Path(out_dir) / "eval",
-        start_seed=hydra_cfg.seed,
-        enable_progbar=True,
-        enable_inner_progbar=True,
-    )
+    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
+        info = eval_policy(
+            env,
+            policy,
+            hydra_cfg.eval.n_episodes,
+            max_episodes_rendered=10,
+            video_dir=Path(out_dir) / "eval",
+            start_seed=hydra_cfg.seed,
+            enable_progbar=True,
+            enable_inner_progbar=True,
+        )
     print(info["aggregated"])
 
     # Save info
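The one-line conditional context above is equivalent to entering both managers explicitly. A hedged sketch with generic names (`rollout_forward` and `observation` are hypothetical, not part of eval.py):

```python
from contextlib import ExitStack, nullcontext

import torch


def rollout_forward(policy, observation, device, use_amp: bool):
    # Spelled-out equivalent of:
    #   with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
    with ExitStack() as stack:
        stack.enter_context(torch.no_grad())  # no gradient tracking during evaluation
        stack.enter_context(torch.autocast(device_type=device.type) if use_amp else nullcontext())
        return policy(observation)
```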

lerobot/scripts/train.py

@@ -15,12 +15,14 @@
 # limitations under the License.
 import logging
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path
 
 import hydra
 import torch
 from omegaconf import DictConfig
+from torch.cuda.amp import GradScaler
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -28,6 +30,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import PolicyWithUpdate
+from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
@@ -83,21 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy):
     return optimizer, lr_scheduler
 
 
-def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+def update_policy(
+    policy,
+    batch,
+    optimizer,
+    grad_clip_norm,
+    grad_scaler: GradScaler,
+    lr_scheduler=None,
+    use_amp: bool = False,
+):
     """Returns a dictionary of items for logging."""
-    start_time = time.time()
+    start_time = time.perf_counter()
+    device = get_device_from_parameters(policy)
     policy.train()
-    output_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = output_dict["loss"]
-    loss.backward()
+    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+        output_dict = policy.forward(batch)
+        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+        loss = output_dict["loss"]
+    grad_scaler.scale(loss).backward()
+
+    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+    grad_scaler.unscale_(optimizer)
+
     grad_norm = torch.nn.utils.clip_grad_norm_(
         policy.parameters(),
         grad_clip_norm,
         error_if_nonfinite=False,
     )
 
-    optimizer.step()
+    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+    grad_scaler.step(optimizer)
+    # Updates the scale for next iteration.
+    grad_scaler.update()
+
     optimizer.zero_grad()
 
     if lr_scheduler is not None:
@@ -111,7 +133,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
-        "update_s": time.time() - start_time,
+        "update_s": time.perf_counter() - start_time,
         **{k: v for k, v in output_dict.items() if k != "loss"},
     }
 
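The comments above follow PyTorch's recommended ordering when gradient clipping is combined with a `GradScaler`: scale the loss for backward, unscale before clipping, then let the scaler drive the optimizer step. A compact standalone sketch of that recipe (generic model/optimizer names, not LeRobot-specific):

```python
from contextlib import nullcontext

import torch
from torch.cuda.amp import GradScaler


def amp_update(model, batch, target, optimizer, scaler: GradScaler, max_norm: float, use_amp: bool):
    device = next(model.parameters()).device
    # 1) Forward and loss under autocast so eligible ops run in reduced precision.
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        loss = torch.nn.functional.mse_loss(model(batch), target)
    # 2) Backward on the scaled loss to avoid float16 gradient underflow.
    scaler.scale(loss).backward()
    # 3) Unscale before clipping so the threshold applies to the true gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # 4) step() skips the update if any gradient is inf/NaN; update() adjusts the scale factor.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    return loss.item()
```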
@@ -219,7 +241,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         raise NotImplementedError("Online training is not implemented yet.")
 
     # Check device is available
-    get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -237,6 +259,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    grad_scaler = GradScaler(enabled=cfg.use_amp)
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
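A note on `GradScaler(enabled=cfg.use_amp)`: when `enabled=False`, the scaler is inert (`scale()` returns the loss unchanged, `unscale_()` and `update()` are no-ops, and `step()` simply calls `optimizer.step()`), so the same `update_policy` code path serves both modes. A small self-contained illustration (toy model, not from the repository):

```python
import torch
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=False)  # mirrors GradScaler(enabled=cfg.use_amp) with use_amp=false

loss = model(torch.randn(8, 4)).mean()
scaler.scale(loss).backward()  # identical to loss.backward() when disabled
scaler.step(optimizer)         # identical to optimizer.step() when disabled
scaler.update()                # no-op when disabled
```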
@@ -257,14 +280,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     def evaluate_and_checkpoint_if_needed(step):
         if step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
-            eval_info = eval_policy(
-                eval_env,
-                policy,
-                cfg.eval.n_episodes,
-                video_dir=Path(out_dir) / "eval",
-                max_episodes_rendered=4,
-                start_seed=cfg.seed,
-            )
+            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+                eval_info = eval_policy(
+                    eval_env,
+                    policy,
+                    cfg.eval.n_episodes,
+                    video_dir=Path(out_dir) / "eval",
+                    max_episodes_rendered=4,
+                    start_seed=cfg.seed,
+                )
             log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(eval_info["video_paths"][0], step, mode="eval")
@@ -288,7 +312,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         num_workers=4,
         batch_size=cfg.training.batch_size,
         shuffle=True,
-        pin_memory=cfg.device != "cpu",
+        pin_memory=device.type != "cpu",
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
@@ -301,9 +325,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         batch = next(dl_iter)
 
         for key in batch:
-            batch[key] = batch[key].to(cfg.device, non_blocking=True)
+            batch[key] = batch[key].to(device, non_blocking=True)
 
-        train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
+        train_info = update_policy(
+            policy,
+            batch,
+            optimizer,
+            cfg.training.grad_clip_norm,
+            grad_scaler=grad_scaler,
+            lr_scheduler=lr_scheduler,
+            use_amp=cfg.use_amp,
+        )
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.training.log_freq == 0:
@@ -329,7 +361,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             num_workers=4,
            batch_size=cfg.training.batch_size,
             sampler=sampler,
-            pin_memory=cfg.device != "cpu",
+            pin_memory=device.type != "cpu",
             drop_last=False,
         )