wip
This commit is contained in:
parent
096149b118
commit
304f83fb5c
|
@ -10,6 +10,9 @@ hydra:
|
|||
name: default
|
||||
|
||||
device: cuda # cpu
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: false
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
seed: ???
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -24,6 +25,7 @@ import torch
|
|||
from datasets import concatenate_datasets
|
||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
||||
from omegaconf import DictConfig
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
|
@ -31,6 +33,7 @@ from lerobot.common.envs.factory import make_env
|
|||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
|
@ -87,21 +90,40 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||
def update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
grad_clip_norm,
|
||||
grad_scaler: GradScaler,
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
):
|
||||
"""Returns a dictionary of items for logging."""
|
||||
start_time = time.time()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
loss = output_dict["loss"]
|
||||
loss.backward()
|
||||
grad_scaler.scale(loss).backward()
|
||||
|
||||
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
||||
grad_scaler.unscale_(optimizer)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
optimizer.step()
|
||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||
grad_scaler.step(optimizer)
|
||||
# Updates the scale for next iteration.
|
||||
grad_scaler.update()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if lr_scheduler is not None:
|
||||
|
@ -320,7 +342,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logging.warning("eval.batch_size > 1 not supported for online training steps")
|
||||
|
||||
# Check device is available
|
||||
get_safe_torch_device(cfg.device, log=True)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
@ -338,6 +360,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
# Create optimizer and scheduler
|
||||
# Temporary hack to move optimizer out of policy
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
@ -358,6 +381,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
def evaluate_and_checkpoint_if_needed(step):
|
||||
if step % cfg.training.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
|
@ -389,7 +413,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
num_workers=4,
|
||||
batch_size=cfg.training.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
@ -403,9 +427,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
batch = next(dl_iter)
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.training.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
)
|
||||
|
||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||
if step % cfg.training.log_freq == 0:
|
||||
|
@ -436,7 +468,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
num_workers=4,
|
||||
batch_size=cfg.training.batch_size,
|
||||
sampler=sampler,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
@ -448,7 +480,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logging.info("Start online training by interacting with environment")
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
eval_info = eval_policy(
|
||||
online_training_env,
|
||||
policy,
|
||||
|
@ -472,9 +504,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
batch = next(dl_iter)
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.training.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
)
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
|
||||
|
|
Loading…
Reference in New Issue