Alexander Soare 2024-05-20 12:57:40 +01:00
parent 096149b118
commit 304f83fb5c
2 changed files with 65 additions and 22 deletions

View File

@@ -10,6 +10,9 @@ hydra:
     name: default

 device: cuda # cpu
+# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+# automatic gradient scaling is used.
+use_amp: false
 # `seed` is used for training (eg: model initialization, dataset shuffling)
 # AND for the evaluation environments.
 seed: ???
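Note on the new `use_amp` flag: it is consumed by the training script below, where it gates both `torch.autocast` around forward passes and a `GradScaler` around backward passes. A minimal sketch of the toggle pattern (not code from this commit; `cfg` and `device` are hypothetical stand-ins):

from contextlib import nullcontext

import torch
from torch.cuda.amp import GradScaler


def make_amp_tools(cfg, device):
    """Return a gradient scaler and an autocast-context factory driven by cfg.use_amp."""
    # With enabled=False the scaler is a no-op: scale() returns the loss unchanged
    # and step() simply calls optimizer.step().
    scaler = GradScaler(enabled=cfg.use_amp)

    def autocast_ctx():
        # Mixed-precision forward pass when AMP is on, plain full precision otherwise.
        return torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext()

    return scaler, autocast_ctx

With this pattern a single training loop serves both settings, which is how the diff below keeps the AMP and non-AMP code paths unified.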

View File

@@ -15,6 +15,7 @@
 # limitations under the License.

 import logging
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path
@@ -24,6 +25,7 @@ import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
 from omegaconf import DictConfig
+from torch.cuda.amp import GradScaler

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -31,6 +33,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import PolicyWithUpdate
+from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
@@ -87,21 +90,40 @@ def make_optimizer_and_scheduler(cfg, policy):
     return optimizer, lr_scheduler


-def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+def update_policy(
+    policy,
+    batch,
+    optimizer,
+    grad_clip_norm,
+    grad_scaler: GradScaler,
+    lr_scheduler=None,
+    use_amp: bool = False,
+):
     """Returns a dictionary of items for logging."""
     start_time = time.time()
+    device = get_device_from_parameters(policy)
     policy.train()
-    output_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = output_dict["loss"]
-    loss.backward()
+    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+        output_dict = policy.forward(batch)
+        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+        loss = output_dict["loss"]
+    grad_scaler.scale(loss).backward()
+
+    # Unscale the gradients of the optimizer's assigned params in-place **prior to gradient clipping**.
+    grad_scaler.unscale_(optimizer)
+
     grad_norm = torch.nn.utils.clip_grad_norm_(
         policy.parameters(),
         grad_clip_norm,
         error_if_nonfinite=False,
     )
-    optimizer.step()
+
+    # The optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+    grad_scaler.step(optimizer)
+    # Update the scale for the next iteration.
+    grad_scaler.update()
+
     optimizer.zero_grad()

     if lr_scheduler is not None:
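The ordering in the rewritten update_policy matters: the loss is scaled before backward() so small fp16 gradients do not underflow, gradients are unscaled before clipping so grad_clip_norm applies to their true magnitudes, and grad_scaler.step()/update() replace the bare optimizer.step(). A standalone sketch of the same sequence with a toy model (not the repository's code):

from contextlib import nullcontext

import torch
from torch.cuda.amp import GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # autocast/scaling only pay off on GPU

model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
grad_scaler = GradScaler(enabled=use_amp)

x = torch.randn(32, 8, device=device)
y = torch.randn(32, 1, device=device)

# Forward pass in mixed precision (or a no-op context when AMP is off).
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
    loss = torch.nn.functional.mse_loss(model(x), y)

# Scale the loss, then backpropagate the scaled gradients.
grad_scaler.scale(loss).backward()

# Unscale in-place before clipping so the threshold sees the true gradient norm.
grad_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

# step() skips the update when gradients contain inf/NaN; update() adapts the scale.
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad()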
@@ -320,7 +342,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         logging.warning("eval.batch_size > 1 not supported for online training steps")

     # Check device is available
-    get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -338,6 +360,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     # Create optimizer and scheduler
     # Temporary hack to move optimizer out of policy
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    grad_scaler = GradScaler(enabled=cfg.use_amp)

     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
@@ -358,14 +381,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     def evaluate_and_checkpoint_if_needed(step):
         if step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
-            eval_info = eval_policy(
-                eval_env,
-                policy,
-                cfg.eval.n_episodes,
-                video_dir=Path(out_dir) / "eval",
-                max_episodes_rendered=4,
-                start_seed=cfg.seed,
-            )
+            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+                eval_info = eval_policy(
+                    eval_env,
+                    policy,
+                    cfg.eval.n_episodes,
+                    video_dir=Path(out_dir) / "eval",
+                    max_episodes_rendered=4,
+                    start_seed=cfg.seed,
+                )
             log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(eval_info["video_paths"][0], step, mode="eval")
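Evaluation gets the same treatment: inference runs under both torch.no_grad() and autocast when use_amp is set, so eval numerics and memory use match mixed-precision training. A hypothetical helper illustrating the pattern (select_action stands in for the policy's inference entry point; it is not taken from this diff):

from contextlib import nullcontext

import torch


def run_inference(policy, observation, device, use_amp: bool):
    # no_grad() drops autograd bookkeeping; autocast() mirrors the precision
    # used during AMP training when use_amp is set.
    policy.eval()
    with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
        return policy.select_action(observation)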
@@ -389,7 +413,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         num_workers=4,
         batch_size=cfg.training.batch_size,
         shuffle=True,
-        pin_memory=cfg.device != "cpu",
+        pin_memory=device.type != "cpu",
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
@@ -403,9 +427,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         batch = next(dl_iter)

         for key in batch:
-            batch[key] = batch[key].to(cfg.device, non_blocking=True)
+            batch[key] = batch[key].to(device, non_blocking=True)

-        train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
+        train_info = update_policy(
+            policy,
+            batch,
+            optimizer,
+            cfg.training.grad_clip_norm,
+            grad_scaler=grad_scaler,
+            lr_scheduler=lr_scheduler,
+            use_amp=cfg.use_amp,
+        )

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.training.log_freq == 0:
@@ -436,7 +468,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         num_workers=4,
         batch_size=cfg.training.batch_size,
         sampler=sampler,
-        pin_memory=cfg.device != "cpu",
+        pin_memory=device.type != "cpu",
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
@@ -448,7 +480,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         logging.info("Start online training by interacting with environment")

         policy.eval()
-        with torch.no_grad():
+        with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
             eval_info = eval_policy(
                 online_training_env,
                 policy,
@@ -472,9 +504,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
             batch = next(dl_iter)

             for key in batch:
-                batch[key] = batch[key].to(cfg.device, non_blocking=True)
+                batch[key] = batch[key].to(device, non_blocking=True)

-            train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
+            train_info = update_policy(
+                policy,
+                batch,
+                optimizer,
+                cfg.training.grad_clip_norm,
+                grad_scaler=grad_scaler,
+                lr_scheduler=lr_scheduler,
+                use_amp=cfg.use_amp,
+            )

             if step % cfg.training.log_freq == 0:
                 log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)