From 2b270d085bdc2e7c64920422537f983cc60a095c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 20 May 2024 18:27:54 +0100 Subject: [PATCH 1/2] Disable online training (#202) Co-authored-by: Remi --- Makefile | 3 +- lerobot/configs/default.yaml | 1 + lerobot/configs/policy/tdmpc.yaml | 3 +- lerobot/scripts/train.py | 162 +----------------------------- 4 files changed, 9 insertions(+), 160 deletions(-) diff --git a/Makefile b/Makefile index a0163f94..9df4fe60 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,7 @@ test-diffusion-ete-eval: env.episode_length=8 \ device=cpu \ +# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated. test-tdmpc-ete-train: python lerobot/scripts/train.py \ policy=tdmpc \ @@ -82,7 +83,7 @@ test-tdmpc-ete-train: dataset_repo_id=lerobot/xarm_lift_medium \ wandb.enable=False \ training.offline_steps=2 \ - training.online_steps=2 \ + training.online_steps=0 \ eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=2 \ diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index ae36b3e2..e35deba1 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -17,6 +17,7 @@ dataset_repo_id: lerobot/pusht training: offline_steps: ??? + # NOTE: `online_steps` is not implemented yet. It's here as a placeholder. online_steps: ??? online_steps_between_rollouts: ??? online_sampling_ratio: 0.5 diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 7e736850..09326ab4 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium training: offline_steps: 25000 - online_steps: 25000 + # TODO(alexander-soare): uncomment when online training gets reinstated + online_steps: 0 # 25000 not implemented yet eval_freq: 5000 online_steps_between_rollouts: 1 online_sampling_ratio: 0.5 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 7ca7a0b3..58a2bd01 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -18,11 +18,8 @@ import time from copy import deepcopy from pathlib import Path -import datasets import hydra import torch -from datasets import concatenate_datasets -from datasets.utils import disable_progress_bars, enable_progress_bars from omegaconf import DictConfig from lerobot.common.datasets.factory import make_dataset @@ -69,7 +66,6 @@ def make_optimizer_and_scheduler(cfg, policy): cfg.training.adam_eps, cfg.training.adam_weight_decay, ) - assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training." from diffusers.optimization import get_scheduler lr_scheduler = get_scheduler( @@ -211,103 +207,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline): logger.log_dict(info, step, mode="eval") -def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): - """ - Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average). - - Parameters: - - n_off (int): Number of offline samples, each with a sampling weight of 1. - - n_on (int): Number of online samples. - - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5). - - The total weight of offline samples is n_off * 1.0. - The total weight of offline samples is n_on * w. - The total combined weight of all samples is n_off + n_on * w. - The fraction of the weight that is online is n_on * w / (n_off + n_on * w). 
- We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on. - The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1)) - """ - assert 0.0 <= pc_on <= 1.0 - return -(n_off * pc_on) / (n_on * (pc_on - 1)) - - -def add_episodes_inplace( - online_dataset: torch.utils.data.Dataset, - concat_dataset: torch.utils.data.ConcatDataset, - sampler: torch.utils.data.WeightedRandomSampler, - hf_dataset: datasets.Dataset, - episode_data_index: dict[str, torch.Tensor], - pc_online_samples: float, -): - """ - Modifies the online_dataset, concat_dataset, and sampler in place by integrating - new episodes from hf_dataset into the online_dataset, updating the concatenated - dataset's structure and adjusting the sampling strategy based on the specified - percentage of online samples. - - Parameters: - - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated. - - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines - offline and online datasets, used for sampling purposes. - - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to - reflect changes in the dataset sizes and specified sampling weights. - - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. - - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. - They indicate the start index and end index of each episode in the dataset. - - pc_online_samples (float): The target percentage of samples that should come from - the online dataset during sampling operations. - - Raises: - - AssertionError: If the first episode_id or index in hf_dataset is not 0 - """ - first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item() - last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item() - first_index = hf_dataset.select_columns("index")[0]["index"].item() - last_index = hf_dataset.select_columns("index")[-1]["index"].item() - # sanity check - assert first_episode_idx == 0, f"{first_episode_idx=} is not 0" - assert first_index == 0, f"{first_index=} is not 0" - assert first_index == episode_data_index["from"][first_episode_idx].item() - assert last_index == episode_data_index["to"][last_episode_idx].item() - 1 - - if len(online_dataset) == 0: - # initialize online dataset - online_dataset.hf_dataset = hf_dataset - online_dataset.episode_data_index = episode_data_index - else: - # get the starting indices of the new episodes and frames to be added - start_episode_idx = last_episode_idx + 1 - start_index = last_index + 1 - - def shift_indices(episode_index, index): - # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to - example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index} - return example - - disable_progress_bars() # map has a tqdm progress bar - hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"]) - enable_progress_bars() - - episode_data_index["from"] += start_index - episode_data_index["to"] += start_index - - # extend online dataset - online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset]) - - # update the concatenated dataset length used during sampling - concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) - - # update the sampling weights for each frame so that online frames 
get sampled a certain percentage of times - len_online = len(online_dataset) - len_offline = len(concat_dataset) - len_online - weight_offline = 1.0 - weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples) - sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset)) - - # update the total number of samples used during sampling - sampler.num_samples = len(concat_dataset) - - def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): if out_dir is None: raise NotImplementedError() @@ -316,8 +215,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No init_logging() - if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1: - logging.warning("eval.batch_size > 1 not supported for online training steps") + if cfg.training.online_steps > 0: + raise NotImplementedError("Online training is not implemented yet.") # Check device is available get_safe_torch_device(cfg.device, log=True) @@ -395,10 +294,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No dl_iter = cycle(dataloader) policy.train() - step = 0 # number of policy update (forward + backward + optim) is_offline = True - for offline_step in range(cfg.training.offline_steps): - if offline_step == 0: + for step in range(cfg.training.offline_steps): + if step == 0: logging.info("Start offline training on a fixed dataset") batch = next(dl_iter) @@ -415,11 +313,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # so we pass in step + 1. evaluate_and_checkpoint_if_needed(step + 1) - step += 1 - - # create an env dedicated to online episodes collection from policy rollout - online_training_env = make_env(cfg, n_envs=1) - # create an empty online dataset similar to offline dataset online_dataset = deepcopy(offline_dataset) online_dataset.hf_dataset = {} @@ -439,55 +332,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No pin_memory=cfg.device != "cpu", drop_last=False, ) - dl_iter = cycle(dataloader) - - online_step = 0 - is_offline = False - for env_step in range(cfg.training.online_steps): - if env_step == 0: - logging.info("Start online training by interacting with environment") - - policy.eval() - with torch.no_grad(): - eval_info = eval_policy( - online_training_env, - policy, - n_episodes=1, - return_episode_data=True, - start_seed=cfg.training.online_env_seed, - enable_progbar=True, - ) - - add_episodes_inplace( - online_dataset, - concat_dataset, - sampler, - hf_dataset=eval_info["episodes"]["hf_dataset"], - episode_data_index=eval_info["episodes"]["episode_data_index"], - pc_online_samples=cfg.training.online_sampling_ratio, - ) - - policy.train() - for _ in range(cfg.training.online_steps_between_rollouts): - batch = next(dl_iter) - - for key in batch: - batch[key] = batch[key].to(cfg.device, non_blocking=True) - - train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler) - - if step % cfg.training.log_freq == 0: - log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) - - # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, - # so we pass in step + 1. 
- evaluate_and_checkpoint_if_needed(step + 1) - - step += 1 - online_step += 1 eval_env.close() - online_training_env.close() logging.info("End of training") From b6c216b5902f532f6066155bc4c99a187a2d3414 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 20 May 2024 18:57:54 +0100 Subject: [PATCH 2/2] Add Automatic Mixed Precision option for training and evaluation. (#199) --- Makefile | 35 ++++++++++++++++- lerobot/configs/default.yaml | 3 ++ lerobot/scripts/eval.py | 24 ++++++------ lerobot/scripts/train.py | 74 ++++++++++++++++++++++++++---------- 4 files changed, 103 insertions(+), 33 deletions(-) diff --git a/Makefile b/Makefile index 9df4fe60..9a8a2474 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,8 @@ build-gpu: test-end-to-end: ${MAKE} test-act-ete-train ${MAKE} test-act-ete-eval + ${MAKE} test-act-ete-train-amp + ${MAKE} test-act-ete-eval-amp ${MAKE} test-diffusion-ete-train ${MAKE} test-diffusion-ete-eval ${MAKE} test-tdmpc-ete-train @@ -29,6 +31,7 @@ test-end-to-end: test-act-ete-train: python lerobot/scripts/train.py \ policy=act \ + policy.dim_model=64 \ env=aloha \ wandb.enable=False \ training.offline_steps=2 \ @@ -51,9 +54,40 @@ test-act-ete-eval: env.episode_length=8 \ device=cpu \ +test-act-ete-train-amp: + python lerobot/scripts/train.py \ + policy=act \ + policy.dim_model=64 \ + env=aloha \ + wandb.enable=False \ + training.offline_steps=2 \ + training.online_steps=0 \ + eval.n_episodes=1 \ + eval.batch_size=1 \ + device=cpu \ + training.save_model=true \ + training.save_freq=2 \ + policy.n_action_steps=20 \ + policy.chunk_size=20 \ + training.batch_size=2 \ + hydra.run.dir=tests/outputs/act/ \ + use_amp=true + +test-act-ete-eval-amp: + python lerobot/scripts/eval.py \ + -p tests/outputs/act/checkpoints/000002 \ + eval.n_episodes=1 \ + eval.batch_size=1 \ + env.episode_length=8 \ + device=cpu \ + use_amp=true + test-diffusion-ete-train: python lerobot/scripts/train.py \ policy=diffusion \ + policy.down_dims=\[64,128,256\] \ + policy.diffusion_step_embed_dim=32 \ + policy.num_inference_steps=10 \ env=pusht \ wandb.enable=False \ training.offline_steps=2 \ @@ -101,7 +135,6 @@ test-tdmpc-ete-eval: env.episode_length=8 \ device=cpu \ - test-default-ete-eval: python lerobot/scripts/eval.py \ --config lerobot/configs/default.yaml \ diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index e35deba1..27329bc9 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -10,6 +10,9 @@ hydra: name: default device: cuda # cpu +# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP, +# automatic gradient scaling is used. +use_amp: false # `seed` is used for training (eg: model initialization, dataset shuffling) # AND for the evaluation environments. seed: ??? 
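Note on the `use_amp` flag added above: at runtime it gates two things, an autocast context around forward passes and a gradient scaler for backpropagation, and the new Makefile targets toggle it per run via the `use_amp=true` Hydra override. The snippet below is a minimal sketch of that gating, not part of either patch; `policy` and `batch` are placeholders, and the actual wiring is in the eval.py and train.py hunks that follow.

from contextlib import nullcontext

import torch
from torch.cuda.amp import GradScaler


def run_in_mixed_precision(policy, batch, device: torch.device, use_amp: bool):
    # Inference-side shape of the gating: autocast only when AMP is requested,
    # otherwise fall back to a no-op context (mirrors the eval.py hunk below).
    with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
        return policy(batch)


# A GradScaler constructed with enabled=False leaves losses and optimizer steps
# untouched, so training code can call scale()/step()/update() unconditionally.
grad_scaler = GradScaler(enabled=False)
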
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 9c95633a..7e4690d0 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -46,6 +46,7 @@ import json
 import logging
 import threading
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from datetime import datetime as dt
 from pathlib import Path
@@ -520,7 +521,7 @@ def eval(
         raise NotImplementedError()

     # Check device is available
-    get_safe_torch_device(hydra_cfg.device, log=True)
+    device = get_safe_torch_device(hydra_cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -539,16 +540,17 @@ def eval(
     policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
     policy.eval()

-    info = eval_policy(
-        env,
-        policy,
-        hydra_cfg.eval.n_episodes,
-        max_episodes_rendered=10,
-        video_dir=Path(out_dir) / "eval",
-        start_seed=hydra_cfg.seed,
-        enable_progbar=True,
-        enable_inner_progbar=True,
-    )
+    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
+        info = eval_policy(
+            env,
+            policy,
+            hydra_cfg.eval.n_episodes,
+            max_episodes_rendered=10,
+            video_dir=Path(out_dir) / "eval",
+            start_seed=hydra_cfg.seed,
+            enable_progbar=True,
+            enable_inner_progbar=True,
+        )
     print(info["aggregated"])

     # Save info
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 58a2bd01..2b28943d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -15,12 +15,14 @@
 # limitations under the License.
 import logging
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path

 import hydra
 import torch
 from omegaconf import DictConfig
+from torch.cuda.amp import GradScaler

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -28,6 +30,7 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import PolicyWithUpdate
+from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
@@ -83,21 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy):
     return optimizer, lr_scheduler


-def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+def update_policy(
+    policy,
+    batch,
+    optimizer,
+    grad_clip_norm,
+    grad_scaler: GradScaler,
+    lr_scheduler=None,
+    use_amp: bool = False,
+):
     """Returns a dictionary of items for logging."""
-    start_time = time.time()
+    start_time = time.perf_counter()
+    device = get_device_from_parameters(policy)
     policy.train()
-    output_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = output_dict["loss"]
-    loss.backward()
+    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+        output_dict = policy.forward(batch)
+        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+        loss = output_dict["loss"]
+    grad_scaler.scale(loss).backward()
+
+    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+ grad_scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( policy.parameters(), grad_clip_norm, error_if_nonfinite=False, ) - optimizer.step() + # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, + # although it still skips optimizer.step() if the gradients contain infs or NaNs. + grad_scaler.step(optimizer) + # Updates the scale for next iteration. + grad_scaler.update() + optimizer.zero_grad() if lr_scheduler is not None: @@ -111,7 +133,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): "loss": loss.item(), "grad_norm": float(grad_norm), "lr": optimizer.param_groups[0]["lr"], - "update_s": time.time() - start_time, + "update_s": time.perf_counter() - start_time, **{k: v for k, v in output_dict.items() if k != "loss"}, } @@ -219,7 +241,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No raise NotImplementedError("Online training is not implemented yet.") # Check device is available - get_safe_torch_device(cfg.device, log=True) + device = get_safe_torch_device(cfg.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -237,6 +259,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Create optimizer and scheduler # Temporary hack to move optimizer out of policy optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) + grad_scaler = GradScaler(enabled=cfg.use_amp) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) @@ -257,14 +280,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No def evaluate_and_checkpoint_if_needed(step): if step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") - eval_info = eval_policy( - eval_env, - policy, - cfg.eval.n_episodes, - video_dir=Path(out_dir) / "eval", - max_episodes_rendered=4, - start_seed=cfg.seed, - ) + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): + eval_info = eval_policy( + eval_env, + policy, + cfg.eval.n_episodes, + video_dir=Path(out_dir) / "eval", + max_episodes_rendered=4, + start_seed=cfg.seed, + ) log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline) if cfg.wandb.enable: logger.log_video(eval_info["video_paths"][0], step, mode="eval") @@ -288,7 +312,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No num_workers=4, batch_size=cfg.training.batch_size, shuffle=True, - pin_memory=cfg.device != "cpu", + pin_memory=device.type != "cpu", drop_last=False, ) dl_iter = cycle(dataloader) @@ -301,9 +325,17 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No batch = next(dl_iter) for key in batch: - batch[key] = batch[key].to(cfg.device, non_blocking=True) + batch[key] = batch[key].to(device, non_blocking=True) - train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler) + train_info = update_policy( + policy, + batch, + optimizer, + cfg.training.grad_clip_norm, + grad_scaler=grad_scaler, + lr_scheduler=lr_scheduler, + use_amp=cfg.use_amp, + ) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? 
if step % cfg.training.log_freq == 0: @@ -329,7 +361,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No num_workers=4, batch_size=cfg.training.batch_size, sampler=sampler, - pin_memory=cfg.device != "cpu", + pin_memory=device.type != "cpu", drop_last=False, )
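
Taken together, the train.py changes above give update_policy the standard AMP ordering: backward on the scaled loss, unscale before clipping so the norm threshold applies to the true gradient values, then let the scaler decide whether the optimizer step is safe to apply. The sketch below condenses those hunks into one function for reference; it is not a drop-in replacement, and `next(policy.parameters()).device` stands in for the repo's `get_device_from_parameters` helper.

from contextlib import nullcontext

import torch
from torch.cuda.amp import GradScaler


def amp_update_step(policy, batch, optimizer, grad_clip_norm, grad_scaler: GradScaler, use_amp: bool):
    """One optimizer update in the order used by `update_policy` above."""
    device = next(policy.parameters()).device  # stand-in for get_device_from_parameters(policy)
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        loss = policy.forward(batch)["loss"]
    grad_scaler.scale(loss).backward()  # backward on the scaled loss
    grad_scaler.unscale_(optimizer)  # restore true gradient magnitudes before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(), grad_clip_norm, error_if_nonfinite=False
    )
    grad_scaler.step(optimizer)  # the optimizer step is skipped if gradients contain infs/NaNs
    grad_scaler.update()  # adapt the loss scale for the next iteration
    optimizer.zero_grad()
    return {"loss": loss.item(), "grad_norm": float(grad_norm)}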