Add Automatic Mixed Precision option for training and evaluation. (#199)
This commit is contained in:
parent
2b270d085b
commit
b6c216b590
35
Makefile
35
Makefile
|
@ -20,6 +20,8 @@ build-gpu:
|
||||||
test-end-to-end:
|
test-end-to-end:
|
||||||
${MAKE} test-act-ete-train
|
${MAKE} test-act-ete-train
|
||||||
${MAKE} test-act-ete-eval
|
${MAKE} test-act-ete-eval
|
||||||
|
${MAKE} test-act-ete-train-amp
|
||||||
|
${MAKE} test-act-ete-eval-amp
|
||||||
${MAKE} test-diffusion-ete-train
|
${MAKE} test-diffusion-ete-train
|
||||||
${MAKE} test-diffusion-ete-eval
|
${MAKE} test-diffusion-ete-eval
|
||||||
${MAKE} test-tdmpc-ete-train
|
${MAKE} test-tdmpc-ete-train
|
||||||
|
@ -29,6 +31,7 @@ test-end-to-end:
|
||||||
test-act-ete-train:
|
test-act-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=act \
|
policy=act \
|
||||||
|
policy.dim_model=64 \
|
||||||
env=aloha \
|
env=aloha \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
training.offline_steps=2 \
|
training.offline_steps=2 \
|
||||||
|
@ -51,9 +54,40 @@ test-act-ete-eval:
|
||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
test-act-ete-train-amp:
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
policy=act \
|
||||||
|
policy.dim_model=64 \
|
||||||
|
env=aloha \
|
||||||
|
wandb.enable=False \
|
||||||
|
training.offline_steps=2 \
|
||||||
|
training.online_steps=0 \
|
||||||
|
eval.n_episodes=1 \
|
||||||
|
eval.batch_size=1 \
|
||||||
|
device=cpu \
|
||||||
|
training.save_model=true \
|
||||||
|
training.save_freq=2 \
|
||||||
|
policy.n_action_steps=20 \
|
||||||
|
policy.chunk_size=20 \
|
||||||
|
training.batch_size=2 \
|
||||||
|
hydra.run.dir=tests/outputs/act/ \
|
||||||
|
use_amp=true
|
||||||
|
|
||||||
|
test-act-ete-eval-amp:
|
||||||
|
python lerobot/scripts/eval.py \
|
||||||
|
-p tests/outputs/act/checkpoints/000002 \
|
||||||
|
eval.n_episodes=1 \
|
||||||
|
eval.batch_size=1 \
|
||||||
|
env.episode_length=8 \
|
||||||
|
device=cpu \
|
||||||
|
use_amp=true
|
||||||
|
|
||||||
test-diffusion-ete-train:
|
test-diffusion-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=diffusion \
|
policy=diffusion \
|
||||||
|
policy.down_dims=\[64,128,256\] \
|
||||||
|
policy.diffusion_step_embed_dim=32 \
|
||||||
|
policy.num_inference_steps=10 \
|
||||||
env=pusht \
|
env=pusht \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
training.offline_steps=2 \
|
training.offline_steps=2 \
|
||||||
|
@ -101,7 +135,6 @@ test-tdmpc-ete-eval:
|
||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
|
||||||
test-default-ete-eval:
|
test-default-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--config lerobot/configs/default.yaml \
|
--config lerobot/configs/default.yaml \
|
||||||
|
|
|
@ -10,6 +10,9 @@ hydra:
|
||||||
name: default
|
name: default
|
||||||
|
|
||||||
device: cuda # cpu
|
device: cuda # cpu
|
||||||
|
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||||
|
# automatic gradient scaling is used.
|
||||||
|
use_amp: false
|
||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: ???
|
seed: ???
|
||||||
|
|
|
@ -46,6 +46,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime as dt
|
from datetime import datetime as dt
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -520,7 +521,7 @@ def eval(
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
get_safe_torch_device(hydra_cfg.device, log=True)
|
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -539,16 +540,17 @@ def eval(
|
||||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
info = eval_policy(
|
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||||
env,
|
info = eval_policy(
|
||||||
policy,
|
env,
|
||||||
hydra_cfg.eval.n_episodes,
|
policy,
|
||||||
max_episodes_rendered=10,
|
hydra_cfg.eval.n_episodes,
|
||||||
video_dir=Path(out_dir) / "eval",
|
max_episodes_rendered=10,
|
||||||
start_seed=hydra_cfg.seed,
|
video_dir=Path(out_dir) / "eval",
|
||||||
enable_progbar=True,
|
start_seed=hydra_cfg.seed,
|
||||||
enable_inner_progbar=True,
|
enable_progbar=True,
|
||||||
)
|
enable_inner_progbar=True,
|
||||||
|
)
|
||||||
print(info["aggregated"])
|
print(info["aggregated"])
|
||||||
|
|
||||||
# Save info
|
# Save info
|
||||||
|
|
|
@ -15,12 +15,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
from torch.cuda.amp import GradScaler
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
|
@ -28,6 +30,7 @@ from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
||||||
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
format_big_number,
|
format_big_number,
|
||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
|
@ -83,21 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
return optimizer, lr_scheduler
|
return optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
def update_policy(
|
||||||
|
policy,
|
||||||
|
batch,
|
||||||
|
optimizer,
|
||||||
|
grad_clip_norm,
|
||||||
|
grad_scaler: GradScaler,
|
||||||
|
lr_scheduler=None,
|
||||||
|
use_amp: bool = False,
|
||||||
|
):
|
||||||
"""Returns a dictionary of items for logging."""
|
"""Returns a dictionary of items for logging."""
|
||||||
start_time = time.time()
|
start_time = time.perf_counter()
|
||||||
|
device = get_device_from_parameters(policy)
|
||||||
policy.train()
|
policy.train()
|
||||||
output_dict = policy.forward(batch)
|
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
output_dict = policy.forward(batch)
|
||||||
loss = output_dict["loss"]
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
loss.backward()
|
loss = output_dict["loss"]
|
||||||
|
grad_scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
||||||
|
grad_scaler.unscale_(optimizer)
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.parameters(),
|
policy.parameters(),
|
||||||
grad_clip_norm,
|
grad_clip_norm,
|
||||||
error_if_nonfinite=False,
|
error_if_nonfinite=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer.step()
|
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||||
|
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||||
|
grad_scaler.step(optimizer)
|
||||||
|
# Updates the scale for next iteration.
|
||||||
|
grad_scaler.update()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
|
@ -111,7 +133,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"grad_norm": float(grad_norm),
|
"grad_norm": float(grad_norm),
|
||||||
"lr": optimizer.param_groups[0]["lr"],
|
"lr": optimizer.param_groups[0]["lr"],
|
||||||
"update_s": time.time() - start_time,
|
"update_s": time.perf_counter() - start_time,
|
||||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,7 +241,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
raise NotImplementedError("Online training is not implemented yet.")
|
raise NotImplementedError("Online training is not implemented yet.")
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(cfg.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -237,6 +259,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
# Create optimizer and scheduler
|
# Create optimizer and scheduler
|
||||||
# Temporary hack to move optimizer out of policy
|
# Temporary hack to move optimizer out of policy
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
|
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
@ -257,14 +280,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
def evaluate_and_checkpoint_if_needed(step):
|
def evaluate_and_checkpoint_if_needed(step):
|
||||||
if step % cfg.training.eval_freq == 0:
|
if step % cfg.training.eval_freq == 0:
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
eval_info = eval_policy(
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||||
eval_env,
|
eval_info = eval_policy(
|
||||||
policy,
|
eval_env,
|
||||||
cfg.eval.n_episodes,
|
policy,
|
||||||
video_dir=Path(out_dir) / "eval",
|
cfg.eval.n_episodes,
|
||||||
max_episodes_rendered=4,
|
video_dir=Path(out_dir) / "eval",
|
||||||
start_seed=cfg.seed,
|
max_episodes_rendered=4,
|
||||||
)
|
start_seed=cfg.seed,
|
||||||
|
)
|
||||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
||||||
if cfg.wandb.enable:
|
if cfg.wandb.enable:
|
||||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||||
|
@ -288,7 +312,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
@ -301,9 +325,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(device, non_blocking=True)
|
||||||
|
|
||||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
train_info = update_policy(
|
||||||
|
policy,
|
||||||
|
batch,
|
||||||
|
optimizer,
|
||||||
|
cfg.training.grad_clip_norm,
|
||||||
|
grad_scaler=grad_scaler,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
use_amp=cfg.use_amp,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.training.log_freq == 0:
|
if step % cfg.training.log_freq == 0:
|
||||||
|
@ -329,7 +361,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue