From a281f5c2e0439af95ddff7bbf1330f1fbc3a6a28 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 20 May 2024 17:55:01 +0100 Subject: [PATCH] ready for review --- Makefile | 3 +- lerobot/configs/default.yaml | 1 + lerobot/configs/policy/tdmpc.yaml | 2 +- lerobot/scripts/train.py | 162 +----------------------------- 4 files changed, 8 insertions(+), 160 deletions(-) diff --git a/Makefile b/Makefile index a0163f94..9df4fe60 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,7 @@ test-diffusion-ete-eval: env.episode_length=8 \ device=cpu \ +# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated. test-tdmpc-ete-train: python lerobot/scripts/train.py \ policy=tdmpc \ @@ -82,7 +83,7 @@ test-tdmpc-ete-train: dataset_repo_id=lerobot/xarm_lift_medium \ wandb.enable=False \ training.offline_steps=2 \ - training.online_steps=2 \ + training.online_steps=0 \ eval.n_episodes=1 \ eval.batch_size=1 \ env.episode_length=2 \ diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index ae36b3e2..e35deba1 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -17,6 +17,7 @@ dataset_repo_id: lerobot/pusht training: offline_steps: ??? + # NOTE: `online_steps` is not implemented yet. It's here as a placeholder. online_steps: ??? online_steps_between_rollouts: ??? online_sampling_ratio: 0.5 diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 7e736850..d2585494 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -5,7 +5,7 @@ dataset_repo_id: lerobot/xarm_lift_medium training: offline_steps: 25000 - online_steps: 25000 + online_steps: 0 # 25000 not implemented yet eval_freq: 5000 online_steps_between_rollouts: 1 online_sampling_ratio: 0.5 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 7ca7a0b3..58a2bd01 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -18,11 +18,8 @@ import time from copy import deepcopy from pathlib import Path -import datasets import hydra import torch -from datasets import concatenate_datasets -from datasets.utils import disable_progress_bars, enable_progress_bars from omegaconf import DictConfig from lerobot.common.datasets.factory import make_dataset @@ -69,7 +66,6 @@ def make_optimizer_and_scheduler(cfg, policy): cfg.training.adam_eps, cfg.training.adam_weight_decay, ) - assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training." from diffusers.optimization import get_scheduler lr_scheduler = get_scheduler( @@ -211,103 +207,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline): logger.log_dict(info, step, mode="eval") -def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float): - """ - Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average). - - Parameters: - - n_off (int): Number of offline samples, each with a sampling weight of 1. - - n_on (int): Number of online samples. - - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5). - - The total weight of offline samples is n_off * 1.0. - The total weight of offline samples is n_on * w. - The total combined weight of all samples is n_off + n_on * w. - The fraction of the weight that is online is n_on * w / (n_off + n_on * w). - We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on. - The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1)) - """ - assert 0.0 <= pc_on <= 1.0 - return -(n_off * pc_on) / (n_on * (pc_on - 1)) - - -def add_episodes_inplace( - online_dataset: torch.utils.data.Dataset, - concat_dataset: torch.utils.data.ConcatDataset, - sampler: torch.utils.data.WeightedRandomSampler, - hf_dataset: datasets.Dataset, - episode_data_index: dict[str, torch.Tensor], - pc_online_samples: float, -): - """ - Modifies the online_dataset, concat_dataset, and sampler in place by integrating - new episodes from hf_dataset into the online_dataset, updating the concatenated - dataset's structure and adjusting the sampling strategy based on the specified - percentage of online samples. - - Parameters: - - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated. - - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines - offline and online datasets, used for sampling purposes. - - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to - reflect changes in the dataset sizes and specified sampling weights. - - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. - - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. - They indicate the start index and end index of each episode in the dataset. - - pc_online_samples (float): The target percentage of samples that should come from - the online dataset during sampling operations. - - Raises: - - AssertionError: If the first episode_id or index in hf_dataset is not 0 - """ - first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item() - last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item() - first_index = hf_dataset.select_columns("index")[0]["index"].item() - last_index = hf_dataset.select_columns("index")[-1]["index"].item() - # sanity check - assert first_episode_idx == 0, f"{first_episode_idx=} is not 0" - assert first_index == 0, f"{first_index=} is not 0" - assert first_index == episode_data_index["from"][first_episode_idx].item() - assert last_index == episode_data_index["to"][last_episode_idx].item() - 1 - - if len(online_dataset) == 0: - # initialize online dataset - online_dataset.hf_dataset = hf_dataset - online_dataset.episode_data_index = episode_data_index - else: - # get the starting indices of the new episodes and frames to be added - start_episode_idx = last_episode_idx + 1 - start_index = last_index + 1 - - def shift_indices(episode_index, index): - # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to - example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index} - return example - - disable_progress_bars() # map has a tqdm progress bar - hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"]) - enable_progress_bars() - - episode_data_index["from"] += start_index - episode_data_index["to"] += start_index - - # extend online dataset - online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset]) - - # update the concatenated dataset length used during sampling - concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) - - # update the sampling weights for each frame so that online frames get sampled a certain percentage of times - len_online = len(online_dataset) - len_offline = len(concat_dataset) - len_online - weight_offline = 1.0 - weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples) - sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset)) - - # update the total number of samples used during sampling - sampler.num_samples = len(concat_dataset) - - def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None): if out_dir is None: raise NotImplementedError() @@ -316,8 +215,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No init_logging() - if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1: - logging.warning("eval.batch_size > 1 not supported for online training steps") + if cfg.training.online_steps > 0: + raise NotImplementedError("Online training is not implemented yet.") # Check device is available get_safe_torch_device(cfg.device, log=True) @@ -395,10 +294,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No dl_iter = cycle(dataloader) policy.train() - step = 0 # number of policy update (forward + backward + optim) is_offline = True - for offline_step in range(cfg.training.offline_steps): - if offline_step == 0: + for step in range(cfg.training.offline_steps): + if step == 0: logging.info("Start offline training on a fixed dataset") batch = next(dl_iter) @@ -415,11 +313,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # so we pass in step + 1. evaluate_and_checkpoint_if_needed(step + 1) - step += 1 - - # create an env dedicated to online episodes collection from policy rollout - online_training_env = make_env(cfg, n_envs=1) - # create an empty online dataset similar to offline dataset online_dataset = deepcopy(offline_dataset) online_dataset.hf_dataset = {} @@ -439,55 +332,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No pin_memory=cfg.device != "cpu", drop_last=False, ) - dl_iter = cycle(dataloader) - - online_step = 0 - is_offline = False - for env_step in range(cfg.training.online_steps): - if env_step == 0: - logging.info("Start online training by interacting with environment") - - policy.eval() - with torch.no_grad(): - eval_info = eval_policy( - online_training_env, - policy, - n_episodes=1, - return_episode_data=True, - start_seed=cfg.training.online_env_seed, - enable_progbar=True, - ) - - add_episodes_inplace( - online_dataset, - concat_dataset, - sampler, - hf_dataset=eval_info["episodes"]["hf_dataset"], - episode_data_index=eval_info["episodes"]["episode_data_index"], - pc_online_samples=cfg.training.online_sampling_ratio, - ) - - policy.train() - for _ in range(cfg.training.online_steps_between_rollouts): - batch = next(dl_iter) - - for key in batch: - batch[key] = batch[key].to(cfg.device, non_blocking=True) - - train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler) - - if step % cfg.training.log_freq == 0: - log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) - - # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, - # so we pass in step + 1. - evaluate_and_checkpoint_if_needed(step + 1) - - step += 1 - online_step += 1 eval_env.close() - online_training_env.close() logging.info("End of training")