From c4da6891719da44139e321cf13327a8773822dde Mon Sep 17 00:00:00 2001
From: Remi
Date: Mon, 20 May 2024 18:30:11 +0200
Subject: [PATCH 1/2] Hot fix to compute validation loss example test (#200)

Co-authored-by: Alexander Soare
---
 examples/4_calculate_validation_loss.py | 27 +++++++++++++++++++++----
 lerobot/common/datasets/utils.py        | 27 ++++++++++++++++++++++++-
 tests/test_examples.py                  |  2 +-
 3 files changed, 50 insertions(+), 6 deletions(-)

diff --git a/examples/4_calculate_validation_loss.py b/examples/4_calculate_validation_loss.py
index 285184d2..1428014b 100644
--- a/examples/4_calculate_validation_loss.py
+++ b/examples/4_calculate_validation_loss.py
@@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
 on the target environment, whether that be in simulation or the real world.
 """
+import math
 from pathlib import Path
 
 import torch
@@ -39,11 +40,29 @@ delta_timestamps = {
     "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
 }
 
-# Load the last 10 episodes of the dataset as a validation set.
-# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
-# For more information on the Slice API, please see:
+# Load the last 10% of episodes of the dataset as a validation set.
+# - Load full dataset
+full_dataset = LeRobotDataset("lerobot/pusht", split="train")
+# - Calculate train and val subsets
+num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
+num_val_episodes = full_dataset.num_episodes - num_train_episodes
+print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
+print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
+print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
+# - Get first frame index of the validation set
+first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
+# - Load frames subset belonging to validation set using the `split` argument.
+# It utilizes the `datasets` library's syntax for slicing datasets.
+# For more information on the Slice API, please see:
 # https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
-val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
+train_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
+)
+val_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
+)
+print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
+print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
 
 # Create dataloader for evaluation.
 val_dataloader = torch.utils.data.DataLoader(
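Note (not part of the patch): the hunk above derives the validation split from episode boundaries instead of a hard-coded frame index. A minimal sketch of the same idea, generalized to an arbitrary validation fraction; the helper name is hypothetical, and it assumes the `LeRobotDataset` attributes and import path used by the examples (`num_episodes`, `episode_data_index`):

    import math

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


    def episode_boundary_splits(repo_id: str, val_fraction: float = 0.1) -> tuple[str, str]:
        """Return `datasets`-style split strings that cut the dataset on an episode boundary."""
        full_dataset = LeRobotDataset(repo_id, split="train")
        num_train_episodes = math.floor(full_dataset.num_episodes * (1 - val_fraction))
        # First frame of the first validation episode; every earlier frame belongs to training.
        first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
        return f"train[:{first_val_frame_index}]", f"train[{first_val_frame_index}:]"


    # Usage mirroring the 90/10 split in the example script:
    # train_split, val_split = episode_boundary_splits("lerobot/pusht", val_fraction=0.1)

Splitting on an episode boundary (rather than an arbitrary frame index) keeps all frames of any given episode in the same subset, which avoids leaking near-duplicate frames between train and validation.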
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 207ccf7c..86fef8d4 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import re
 from pathlib import Path
 from typing import Dict
 
@@ -80,7 +81,23 @@ def hf_transform_to_torch(items_dict):
 def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
+        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
+        # TODO(rcadene): clean this which enables getting a subset of dataset
+        if split != "train":
+            if "%" in split:
+                raise NotImplementedError(f"We don't support splitting based on percentage for now ({split}).")
+            match_from = re.search(r"train\[(\d+):\]", split)
+            match_to = re.search(r"train\[:(\d+)\]", split)
+            if match_from:
+                from_frame_index = int(match_from.group(1))
+                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
+            elif match_to:
+                to_frame_index = int(match_to.group(1))
+                hf_dataset = hf_dataset.select(range(to_frame_index))
+            else:
+                raise ValueError(
+                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
+                )
     else:
         hf_dataset = load_dataset(repo_id, revision=version, split=split)
     hf_dataset.set_transform(hf_transform_to_torch)
@@ -273,6 +290,12 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
         "to": [3, 7, 12]
     }
     """
+    if len(hf_dataset) == 0:
+        episode_data_index = {
+            "from": torch.tensor([]),
+            "to": torch.tensor([]),
+        }
+        return episode_data_index
     for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
         if episode_idx != current_episode:
             # We encountered a new episode, so we append its starting location to the "from" list
@@ -303,6 +326,8 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
 
     This brings the `episode_index` to the required format.
     """
+    if len(hf_dataset) == 0:
+        return hf_dataset
     unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
     episode_idx_to_reset_idx_mapping = {
         ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 9881e3fa..a0c60b7e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -111,7 +111,7 @@ def test_examples_2_through_4():
             '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
             'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
         ),
-        ('split="train[24342:]"', 'split="train[-1:]"'),
+        ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
         ("num_workers=4", "num_workers=0"),
         ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
         ("batch_size=64", "batch_size=1"),
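Note (not part of the patch): the `load_hf_dataset()` change in lerobot/common/datasets/utils.py above parses the `split` string with two regular expressions so that locally stored datasets can also be sliced. A standalone sketch of that parsing, with a hypothetical function name and illustrative numbers, assuming only the accepted formats shown in the hunk ("train", "train[INT:]", "train[:INT]"):

    import re


    def parse_train_slice(split: str, dataset_len: int) -> range:
        """Map "train", "train[INT:]" or "train[:INT]" to the range of frame indices to keep."""
        if split == "train":
            return range(dataset_len)
        if "%" in split:
            raise NotImplementedError(f"We don't support splitting based on percentage for now ({split}).")
        match_from = re.search(r"train\[(\d+):\]", split)
        match_to = re.search(r"train\[:(\d+)\]", split)
        if match_from:
            # Keep everything from the given frame index to the end of the dataset.
            return range(int(match_from.group(1)), dataset_len)
        if match_to:
            # Keep everything up to (but excluding) the given frame index.
            return range(int(match_to.group(1)))
        raise ValueError(f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"')


    print(parse_train_slice("train[:900]", 1000))  # range(0, 900) -> first 900 frames for training
    print(parse_train_slice("train[900:]", 1000))  # range(900, 1000) -> last 100 frames for validation

In the patch itself, the resulting range is passed to `hf_dataset.select(...)`, which is what actually restricts the loaded dataset to the requested frames.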
From 2b270d085bdc2e7c64920422537f983cc60a095c Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Mon, 20 May 2024 18:27:54 +0100
Subject: [PATCH 2/2] Disable online training (#202)

Co-authored-by: Remi
---
 Makefile                          |   3 +-
 lerobot/configs/default.yaml      |   1 +
 lerobot/configs/policy/tdmpc.yaml |   3 +-
 lerobot/scripts/train.py          | 162 +-----------------------------
 4 files changed, 9 insertions(+), 160 deletions(-)

diff --git a/Makefile b/Makefile
index a0163f94..9df4fe60 100644
--- a/Makefile
+++ b/Makefile
@@ -74,6 +74,7 @@ test-diffusion-ete-eval:
 		env.episode_length=8 \
 		device=cpu \
 
+# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 		policy=tdmpc \
@@ -82,7 +83,7 @@ test-tdmpc-ete-train:
 		dataset_repo_id=lerobot/xarm_lift_medium \
 		wandb.enable=False \
 		training.offline_steps=2 \
-		training.online_steps=2 \
+		training.online_steps=0 \
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		env.episode_length=2 \
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index ae36b3e2..e35deba1 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -17,6 +17,7 @@ dataset_repo_id: lerobot/pusht
 
 training:
   offline_steps: ???
+  # NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
   online_steps: ???
   online_steps_between_rollouts: ???
   online_sampling_ratio: 0.5
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index 7e736850..09326ab4 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium
 
 training:
   offline_steps: 25000
-  online_steps: 25000
+  # TODO(alexander-soare): uncomment when online training gets reinstated
+  online_steps: 0 # 25000 not implemented yet
   eval_freq: 5000
   online_steps_between_rollouts: 1
   online_sampling_ratio: 0.5
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 7ca7a0b3..58a2bd01 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -18,11 +18,8 @@ import time
 from copy import deepcopy
 from pathlib import Path
 
-import datasets
 import hydra
 import torch
-from datasets import concatenate_datasets
-from datasets.utils import disable_progress_bars, enable_progress_bars
 from omegaconf import DictConfig
 
 from lerobot.common.datasets.factory import make_dataset
@@ -69,7 +66,6 @@ def make_optimizer_and_scheduler(cfg, policy):
             cfg.training.adam_eps,
             cfg.training.adam_weight_decay,
         )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
         from diffusers.optimization import get_scheduler
 
         lr_scheduler = get_scheduler(
@@ -211,103 +207,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
     logger.log_dict(info, step, mode="eval")
 
 
-def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
-    """
-    Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average).
-
-    Parameters:
-    - n_off (int): Number of offline samples, each with a sampling weight of 1.
-    - n_on (int): Number of online samples.
-    - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
-
-    The total weight of offline samples is n_off * 1.0.
-    The total weight of online samples is n_on * w.
-    The total combined weight of all samples is n_off + n_on * w.
-    The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
-    We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
-    The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
-    """
-    assert 0.0 <= pc_on <= 1.0
-    return -(n_off * pc_on) / (n_on * (pc_on - 1))
-
-
-def add_episodes_inplace(
-    online_dataset: torch.utils.data.Dataset,
-    concat_dataset: torch.utils.data.ConcatDataset,
-    sampler: torch.utils.data.WeightedRandomSampler,
-    hf_dataset: datasets.Dataset,
-    episode_data_index: dict[str, torch.Tensor],
-    pc_online_samples: float,
-):
-    """
-    Modifies the online_dataset, concat_dataset, and sampler in place by integrating
-    new episodes from hf_dataset into the online_dataset, updating the concatenated
-    dataset's structure and adjusting the sampling strategy based on the specified
-    percentage of online samples.
-
-    Parameters:
-    - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
-    - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
-      offline and online datasets, used for sampling purposes.
-    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
-      reflect changes in the dataset sizes and specified sampling weights.
-    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
-    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
-      They indicate the start index and end index of each episode in the dataset.
-    - pc_online_samples (float): The target percentage of samples that should come from
-      the online dataset during sampling operations.
-
-    Raises:
-    - AssertionError: If the first episode_id or index in hf_dataset is not 0
-    """
-    first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
-    last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
-    first_index = hf_dataset.select_columns("index")[0]["index"].item()
-    last_index = hf_dataset.select_columns("index")[-1]["index"].item()
-    # sanity check
-    assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
-    assert first_index == 0, f"{first_index=} is not 0"
-    assert first_index == episode_data_index["from"][first_episode_idx].item()
-    assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
-
-    if len(online_dataset) == 0:
-        # initialize online dataset
-        online_dataset.hf_dataset = hf_dataset
-        online_dataset.episode_data_index = episode_data_index
-    else:
-        # get the starting indices of the new episodes and frames to be added
-        start_episode_idx = last_episode_idx + 1
-        start_index = last_index + 1
-
-        def shift_indices(episode_index, index):
-            # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
-            example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
-            return example
-
-        disable_progress_bars()  # map has a tqdm progress bar
-        hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
-        enable_progress_bars()
-
-        episode_data_index["from"] += start_index
-        episode_data_index["to"] += start_index
-
-        # extend online dataset
-        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
-
-    # update the concatenated dataset length used during sampling
-    concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
-
-    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
-    len_online = len(online_dataset)
-    len_offline = len(concat_dataset) - len_online
-    weight_offline = 1.0
-    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
-    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
-
-    # update the total number of samples used during sampling
-    sampler.num_samples = len(concat_dataset)
-
-
 def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
@@ -316,8 +215,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 
     init_logging()
 
-    if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
-        logging.warning("eval.batch_size > 1 not supported for online training steps")
+    if cfg.training.online_steps > 0:
+        raise NotImplementedError("Online training is not implemented yet.")
 
     # Check device is available
     get_safe_torch_device(cfg.device, log=True)
@@ -395,10 +294,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     dl_iter = cycle(dataloader)
 
     policy.train()
-    step = 0  # number of policy update (forward + backward + optim)
     is_offline = True
-    for offline_step in range(cfg.training.offline_steps):
-        if offline_step == 0:
+    for step in range(cfg.training.offline_steps):
+        if step == 0:
             logging.info("Start offline training on a fixed dataset")
 
         batch = next(dl_iter)
@@ -415,11 +313,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         # so we pass in step + 1.
         evaluate_and_checkpoint_if_needed(step + 1)
 
-        step += 1
-
-    # create an env dedicated to online episodes collection from policy rollout
-    online_training_env = make_env(cfg, n_envs=1)
-
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
     online_dataset.hf_dataset = {}
@@ -439,55 +332,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         pin_memory=cfg.device != "cpu",
         drop_last=False,
     )
-    dl_iter = cycle(dataloader)
-
-    online_step = 0
-    is_offline = False
-    for env_step in range(cfg.training.online_steps):
-        if env_step == 0:
-            logging.info("Start online training by interacting with environment")
-
-        policy.eval()
-        with torch.no_grad():
-            eval_info = eval_policy(
-                online_training_env,
-                policy,
-                n_episodes=1,
-                return_episode_data=True,
-                start_seed=cfg.training.online_env_seed,
-                enable_progbar=True,
-            )
-
-        add_episodes_inplace(
-            online_dataset,
-            concat_dataset,
-            sampler,
-            hf_dataset=eval_info["episodes"]["hf_dataset"],
-            episode_data_index=eval_info["episodes"]["episode_data_index"],
-            pc_online_samples=cfg.training.online_sampling_ratio,
-        )
-
-        policy.train()
-        for _ in range(cfg.training.online_steps_between_rollouts):
-            batch = next(dl_iter)
-
-            for key in batch:
-                batch[key] = batch[key].to(cfg.device, non_blocking=True)
-
-            train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
-
-            if step % cfg.training.log_freq == 0:
-                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
-
-            # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
-            # so we pass in step + 1.
-            evaluate_and_checkpoint_if_needed(step + 1)
-
-            step += 1
-            online_step += 1
 
     eval_env.close()
-    online_training_env.close()
     logging.info("End of training")
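Note (not part of the patch): the docstring of the removed `calculate_online_sample_weight()` derives the per-sample weight w = -(n_off * pc_on) / (n_on * (pc_on - 1)) that makes online frames account for a fraction pc_on of the total sampling weight. A quick numeric sanity check of that derivation, reproducing the removed helper with illustrative numbers:

    def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float) -> float:
        # Reproduces the removed helper: weight w for each online sample so that, with
        # n_off offline samples of weight 1.0, online samples make up pc_on of the total weight.
        assert 0.0 <= pc_on <= 1.0
        return -(n_off * pc_on) / (n_on * (pc_on - 1))


    # 1000 offline frames and 100 online frames, targeting a 50/50 mix (pc_on = 0.5):
    w = calculate_online_sample_weight(n_off=1000, n_on=100, pc_on=0.5)
    assert abs(w - 10.0) < 1e-12
    # The online fraction of the total weight is indeed pc_on: 100 * 10 / (1000 + 100 * 10) = 0.5
    assert abs((100 * w) / (1000 + 100 * w) - 0.5) < 1e-12

This is the weighting scheme that the WeightedRandomSampler used before online training was disabled, and the derivation remains relevant for when it is reinstated.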