From 35a573c98eeeca6c09e131afe38e6cab16c1c8d3 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 19 Apr 2024 18:17:13 +0000 Subject: [PATCH] Use v1.1, hf_transform_to_torch, Add 3 xarm datasets --- README.md | 2 +- download_and_upload_dataset.py | 35 ++++++++++--------- examples/1_load_hugging_face_dataset.py | 14 +++++--- examples/2_load_lerobot_dataset.py | 5 ++- lerobot/__init__.py | 7 +++- lerobot/common/datasets/aloha.py | 2 +- lerobot/common/datasets/pusht.py | 2 +- lerobot/common/datasets/utils.py | 45 +++++++++++++++++-------- lerobot/common/datasets/xarm.py | 14 ++++++-- lerobot/scripts/eval.py | 17 ++++++---- lerobot/scripts/train.py | 16 ++++++--- tests/test_datasets.py | 37 +++++++++----------- 12 files changed, 122 insertions(+), 74 deletions(-) diff --git a/README.md b/README.md index 202b90e6..a0045bf2 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS You will need to set the corresponding version as a default argument in your dataset class: ```python - version: str | None = "v1.0", + version: str | None = "v1.1", ``` See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py) diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py index 062db690..def1cd59 100644 --- a/download_and_upload_dataset.py +++ b/download_and_upload_dataset.py @@ -19,7 +19,7 @@ from huggingface_hub import HfApi from PIL import Image as PILImage from safetensors.torch import save_file -from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict +from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch def download_and_upload(root, revision, dataset_id): @@ -127,7 +127,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat # copy in tests folder, the first episode and the meta_data directory num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] - hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train") + hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk( + f"tests/data/{dataset_id}/train" + ) + if Path(f"tests/data/{dataset_id}/meta_data").exists(): + shutil.rmtree(f"tests/data/{dataset_id}/meta_data") shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data") @@ -262,8 +266,7 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): } features = Features(features) hf_dataset = Dataset.from_dict(data_dict, features=features) - hf_dataset = hf_dataset.with_format("torch") - hf_dataset.set_transform(convert_images_to_channel_first_tensors) + hf_dataset.set_transform(hf_transform_to_torch) info = { "fps": fps, @@ -274,13 +277,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): def download_and_upload_xarm(root, revision, dataset_id, fps=15): root = Path(root) - raw_dir = root / f"{dataset_id}_raw" + raw_dir = root / "xarm_datasets_raw" if not raw_dir.exists(): import zipfile import gdown raw_dir.mkdir(parents=True, exist_ok=True) + # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" zip_path = raw_dir / "data.zip" gdown.download(url, str(zip_path), quiet=False) @@ -361,8 +365,7 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15): } features = Features(features) 
hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset = hf_dataset.with_format("torch")
-    hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+    hf_dataset.set_transform(hf_transform_to_torch)
 
     info = {
         "fps": fps,
@@ -468,8 +471,6 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
                 # "next.reward": reward,
                 "next.done": done,
                 # "next.success": success,
-                "episode_data_index_from": torch.tensor([id_from] * num_frames),
-                "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
             }
         )
 
@@ -499,8 +500,7 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
     }
     features = Features(features)
     hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset = hf_dataset.with_format("torch")
-    hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+    hf_dataset.set_transform(hf_transform_to_torch)
 
     info = {
         "fps": fps,
@@ -515,11 +515,14 @@ if __name__ == "__main__":
 
     dataset_ids = [
         "pusht",
-        # "xarm_lift_medium",
-        # "aloha_sim_insertion_human",
-        # "aloha_sim_insertion_scripted",
-        # "aloha_sim_transfer_cube_human",
-        # "aloha_sim_transfer_cube_scripted",
+        "xarm_lift_medium",
+        "xarm_lift_medium_replay",
+        "xarm_push_medium",
+        "xarm_push_medium_replay",
+        "aloha_sim_insertion_human",
+        "aloha_sim_insertion_scripted",
+        "aloha_sim_transfer_cube_human",
+        "aloha_sim_transfer_cube_scripted",
     ]
     for dataset_id in dataset_ids:
         download_and_upload(root, revision, dataset_id)
diff --git a/examples/1_load_hugging_face_dataset.py b/examples/1_load_hugging_face_dataset.py
index d70a1286..2b58fbde 100644
--- a/examples/1_load_hugging_face_dataset.py
+++ b/examples/1_load_hugging_face_dataset.py
@@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
 This script supports several Hugging Face datasets, among which:
 1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
 2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
-3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
-4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
-5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
-6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
+3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
+4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
+5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
+6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
+7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
+8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
+9. 
[Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted) To try a different Hugging Face dataset, you can replace this line: ```python @@ -22,6 +25,9 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10 by one of these: ```python hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15 +hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15 +hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15 +hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15 hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50 hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50 hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50 diff --git a/examples/2_load_lerobot_dataset.py b/examples/2_load_lerobot_dataset.py index e782e66f..d5289699 100644 --- a/examples/2_load_lerobot_dataset.py +++ b/examples/2_load_lerobot_dataset.py @@ -18,7 +18,10 @@ dataset = PushtDataset() ``` by one of these: ```python -dataset = XarmDataset() +dataset = XarmDataset("xarm_lift_medium") +dataset = XarmDataset("xarm_lift_medium_replay") +dataset = XarmDataset("xarm_push_medium") +dataset = XarmDataset("xarm_push_medium_replay") dataset = AlohaDataset("aloha_sim_insertion_human") dataset = AlohaDataset("aloha_sim_insertion_scripted") dataset = AlohaDataset("aloha_sim_transfer_cube_human") diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 83e51c7a..70d7d7b0 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -50,7 +50,12 @@ available_datasets = { "aloha_sim_transfer_cube_scripted", ], "pusht": ["pusht"], - "xarm": ["xarm_lift_medium"], + "xarm": [ + "xarm_lift_medium", + "xarm_lift_medium_replay", + "xarm_push_medium", + "xarm_push_medium_replay", + ], } available_policies = [ diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 6d993df0..b26c1a5c 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -27,7 +27,7 @@ class AlohaDataset(torch.utils.data.Dataset): def __init__( self, dataset_id: str, - version: str | None = "v1.0", + version: str | None = "v1.1", root: Path | None = None, split: str = "train", transform: callable = None, diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 7fdd88e0..fc1a556d 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -29,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset): def __init__( self, dataset_id: str = "pusht", - version: str | None = "v1.0", + version: str | None = "v1.1", root: Path | None = None, split: str = "train", transform: callable = None, diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 296f7431..e019cc12 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -6,9 +6,11 @@ import datasets import einops import torch import tqdm -from datasets import load_dataset, load_from_disk +from datasets import Image, load_dataset, load_from_disk from huggingface_hub import hf_hub_download +from PIL import Image as PILImage from safetensors.torch import load_file +from torchvision import transforms from lerobot.common.utils.utils import set_global_seed @@ -37,15 +39,32 @@ def unflatten_dict(d, sep="/"): return outdict +def hf_transform_to_torch(items_dict): + """Get a transform function 
that converts items from a Hugging Face dataset (pyarrow)
+    to torch tensors. Importantly, images are converted from PIL, which corresponds to
+    a channel last representation (h w c) of uint8 type, to a torch image representation
+    with channel first (c h w) of float32 type in range [0,1].
+    """
+    for key in items_dict:
+        first_item = items_dict[key][0]
+        if isinstance(first_item, PILImage.Image):
+            to_tensor = transforms.ToTensor()
+            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        else:
+            items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
+    return items_dict
+
+
 def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
         hf_dataset = load_from_disk(Path(root) / dataset_id / split)
     else:
+        # TODO(rcadene): remove dataset_id everywhere and use repo_id instead
         repo_id = f"lerobot/{dataset_id}"
         hf_dataset = load_dataset(repo_id, revision=version, split=split)
         hf_dataset = hf_dataset.with_format("torch")
-    hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+    hf_dataset.set_transform(hf_transform_to_torch)
     return hf_dataset
 
 
@@ -172,16 +191,6 @@ def load_previous_and_future_frames(
     return item
 
 
-def convert_images_to_channel_first_tensors(examples):
-    for key in examples:
-        if examples[key].ndim == 3:  # we assume it's an image
-            # (h w c) -> (c h w)
-            h, w, c = examples[key].shape
-            assert c < h and c < w, f"expect a channel last image, but instead {examples[key].shape}"
-            examples[key] = [img.permute((2, 0, 1)) for img in examples[key]]
-    return examples
-
-
 def get_stats_einops_patterns(hf_dataset):
     """These einops patterns will be used to aggregate batches and compute statistics.
 
@@ -198,11 +207,19 @@
 
     stats_patterns = {}
     for key, feats_type in hf_dataset.features.items():
-        if batch[key].ndim == 4 and isinstance(feats_type, datasets.features.image.Image):
+        # sanity check that tensors are not float64
+        assert batch[key].dtype != torch.float64
+
+        if isinstance(feats_type, Image):
             # sanity check that images are channel first
             _, c, h, w = batch[key].shape
             assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
-            # convert from (h w c) to (c h w) to fit pytorch convention, then apply reduce
+
+            # sanity check that images are float32 in range [0,1]
+            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
+            assert batch[key].max() <= 1, f"expect pixels lower than or equal to 1, but instead {batch[key].max()=}"
+            assert batch[key].min() >= 0, f"expect pixels greater than or equal to 0, but instead {batch[key].min()=}"
+
             stats_patterns[key] = "b c h w -> c 1 1"
         elif batch[key].ndim == 2:
             stats_patterns[key] = "b c -> c "
diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py
index 4adff9e9..0d995b5e 100644
--- a/lerobot/common/datasets/xarm.py
+++ b/lerobot/common/datasets/xarm.py
@@ -9,17 +9,25 @@ from lerobot.common.datasets.utils import load_previous_and_future_frames
 class XarmDataset(torch.utils.data.Dataset):
     """
     https://huggingface.co/datasets/lerobot/xarm_lift_medium
+    https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
+    https://huggingface.co/datasets/lerobot/xarm_push_medium
+    https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
     """
 
     # Copied from lerobot/__init__.py
-    available_datasets = ["xarm_lift_medium"]
+    available_datasets = [
+        "xarm_lift_medium",
+        "xarm_lift_medium_replay",
+        
"xarm_push_medium", + "xarm_push_medium_replay", + ] fps = 15 image_keys = ["observation.image"] def __init__( self, - dataset_id: str = "xarm_lift_medium", - version: str | None = "v1.0", + dataset_id: str, + version: str | None = "v1.1", root: Path | None = None, split: str = "train", transform: callable = None, diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 0c0e8e8b..28f354e1 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -208,11 +208,12 @@ def eval_policy( max_rewards.extend(batch_max_reward.tolist()) all_successes.extend(batch_success.tolist()) - # similar logic is implemented in dataset preprocessing + # similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`) ep_dicts = [] + episode_data_index = {"from": [], "to": []} num_episodes = dones.shape[0] total_frames = 0 - idx_from = 0 + id_from = 0 for ep_id in range(num_episodes): num_frames = done_indices[ep_id].item() + 1 total_frames += num_frames @@ -227,14 +228,15 @@ def eval_policy( "timestamp": torch.arange(0, num_frames, 1) / fps, "next.done": dones[ep_id, :num_frames], "next.reward": rewards[ep_id, :num_frames].type(torch.float32), - "episode_data_index_from": torch.tensor([idx_from] * num_frames), - "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames), } for key in observations: ep_dict[key] = observations[key][ep_id][:num_frames] ep_dicts.append(ep_dict) - idx_from += num_frames + episode_data_index["from"].append(id_from) + episode_data_index["to"].append(id_from + num_frames) + + id_from += num_frames # similar logic is implemented in dataset preprocessing if return_episode_data: @@ -307,7 +309,10 @@ def eval_policy( }, } if return_episode_data: - info["episodes"] = hf_dataset + info["episodes"] = { + "hf_dataset": hf_dataset, + "episode_data_index": episode_data_index, + } if max_episodes_rendered > 0: info["videos"] = videos return info diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 473bf237..59fb199a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -136,6 +136,7 @@ def add_episodes_inplace( concat_dataset: torch.utils.data.ConcatDataset, sampler: torch.utils.data.WeightedRandomSampler, hf_dataset: datasets.Dataset, + episode_data_index: dict[str, torch.Tensor], pc_online_samples: float, ): """ @@ -151,6 +152,8 @@ def add_episodes_inplace( - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to reflect changes in the dataset sizes and specified sampling weights. - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added. + - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. + They indicate the start index and end index of each episode in the dataset. - pc_online_samples (float): The target percentage of samples that should come from the online dataset during sampling operations. 
@@ -174,14 +177,15 @@ def add_episodes_inplace(
         # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
         example["episode_index"] += start_episode
         example["index"] += start_index
-        example["episode_data_index_from"] += start_index
-        example["episode_data_index_to"] += start_index
         return example
 
     disable_progress_bars()  # map has a tqdm progress bar
     hf_dataset = hf_dataset.map(shift_indices)
     enable_progress_bars()
 
+    episode_data_index["from"] += start_index
+    episode_data_index["to"] += start_index
+
     # extend online dataset
     online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
 
@@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
             seed=cfg.seed,
         )
 
-        online_pc_sampling = cfg.get("demo_schedule", 0.5)
         add_episodes_inplace(
-            online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
+            online_dataset,
+            concat_dataset,
+            sampler,
+            hf_dataset=eval_info["episodes"]["hf_dataset"],
+            episode_data_index=eval_info["episodes"]["episode_data_index"],
+            pc_online_samples=cfg.get("demo_schedule", 0.5),
         )
 
         for _ in range(cfg.policy.utd):
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 8b0428d9..fd333be0 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -18,7 +18,6 @@ from lerobot.common.datasets.utils import (
     load_previous_and_future_frames,
     unflatten_dict,
 )
-from lerobot.common.transforms import Prod
 from lerobot.common.utils.utils import init_hydra_config
 
 from .utils import DEFAULT_CONFIG_PATH, DEVICE
@@ -102,22 +101,18 @@ def test_compute_stats_on_xarm():
 
     data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
 
-    # get transform to convert images from uint8 [0,255] to float32 [0,1]
-    transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
-
     dataset = XarmDataset(
        dataset_id="xarm_lift_medium",
        root=data_dir,
-        transform=transform,
     )
 
     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches. 
- computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25)) + computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25)) # get einops patterns to aggregate batches and compute statistics - stats_patterns = get_stats_einops_patterns(dataset) + stats_patterns = get_stats_einops_patterns(dataset.hf_dataset) # get all frames from the dataset in the same dtype and range as during compute_stats dataloader = torch.utils.data.DataLoader( @@ -126,18 +121,18 @@ def test_compute_stats_on_xarm(): batch_size=len(dataset), shuffle=False, ) - hf_dataset = next(iter(dataloader)) + full_batch = next(iter(dataloader)) # compute stats based on all frames from the dataset without any batching expected_stats = {} for k, pattern in stats_patterns.items(): expected_stats[k] = {} - expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean") + expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean") expected_stats[k]["std"] = torch.sqrt( - einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean") + einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean") ) - expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min") - expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max") + expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min") + expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max") # test computed stats match expected stats for k in stats_patterns: @@ -146,17 +141,15 @@ def test_compute_stats_on_xarm(): assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"]) assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"]) - # TODO(rcadene): check that the stats used for training are correct too - # # load stats that are expected to match the ones returned by computed_stats - # assert (dataset.data_dir / "stats.pth").exists() - # loaded_stats = torch.load(dataset.data_dir / "stats.pth") + # load stats used during training which are expected to match the ones returned by computed_stats + loaded_stats = dataset.stats - # # test loaded stats match expected stats - # for k in stats_patterns: - # assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"]) - # assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"]) - # assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"]) - # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"]) + # test loaded stats match expected stats + for k in stats_patterns: + assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"]) + assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"]) + assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"]) + assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"]) def test_load_previous_and_future_frames_within_tolerance():
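A minimal usage sketch (not part of the patch) of the `hf_transform_to_torch` transform introduced above. It mirrors what `load_hf_dataset` does and the sanity checks added in `get_stats_einops_patterns`, and assumes the `lerobot/xarm_lift_medium` dataset has been pushed to the hub with the `v1.1` revision as described in the README change:

```python
# Sketch only: load one of the v1.1 hub datasets and apply the new transform.
import torch
from datasets import load_dataset

from lerobot.common.datasets.utils import hf_transform_to_torch

hf_dataset = load_dataset("lerobot/xarm_lift_medium", revision="v1.1", split="train")
hf_dataset.set_transform(hf_transform_to_torch)

item = hf_dataset[0]
image = item["observation.image"]
# PIL (h w c) uint8 -> torch (c h w) float32 in [0, 1], as described in the docstring above.
assert image.dtype == torch.float32
assert image.ndim == 3 and image.shape[0] < image.shape[1]
assert 0.0 <= image.min() and image.max() <= 1.0
```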