Use v1.1, hf_transform_to_torch, Add 3 xarm datasets

This commit is contained in:
Cadene 2024-04-19 18:17:13 +00:00
parent 714a776277
commit 35a573c98e
12 changed files with 122 additions and 74 deletions

View File

@ -208,7 +208,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS
You will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.0",
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
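For illustration, a minimal sketch (assuming the `PushtDataset` constructor shown later in this diff) of how the new default version is picked up:
```python
from lerobot.common.datasets.pusht import PushtDataset

# With no explicit version, the dataset class now resolves to the "v1.1" hub revision.
dataset = PushtDataset()
# An older revision can still be pinned explicitly.
dataset_v1_0 = PushtDataset(version="v1.0")
```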

View File

@ -19,7 +19,7 @@ from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.torch import save_file
from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
def download_and_upload(root, revision, dataset_id):
@ -127,7 +127,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
# copy the first episode and the meta_data directory into the tests folder
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
f"tests/data/{dataset_id}/train"
)
if Path(f"tests/data/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
@ -262,8 +266,7 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
@ -274,13 +277,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
raw_dir = root / "xarm_datasets_raw"
if not raw_dir.exists():
import zipfile
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
@ -361,8 +365,7 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
@ -468,8 +471,6 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
# "next.reward": reward,
"next.done": done,
# "next.success": success,
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
)
@ -499,8 +500,7 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
@ -515,11 +515,14 @@ if __name__ == "__main__":
dataset_ids = [
"pusht",
# "xarm_lift_medium",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
# "aloha_sim_transfer_cube_scripted",
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
for dataset_id in dataset_ids:
download_and_upload(root, revision, dataset_id)
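As a usage note, a single one of the newly added xarm datasets can be processed the same way (a sketch; the `root` and `revision` values here are placeholders):
```python
# Hedged sketch: download, convert, and push one of the new xarm datasets.
download_and_upload(root="data", revision="v1.1", dataset_id="xarm_push_medium_replay")
```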

View File

@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
This script supports several Hugging Face datasets, among which:
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
To try a different Hugging Face dataset, you can replace this line:
```python
@ -22,6 +25,9 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
by one of these:
```python
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50

View File

@ -18,7 +18,10 @@ dataset = PushtDataset()
```
by one of these:
```python
dataset = XarmDataset()
dataset = XarmDataset("xarm_lift_medium")
dataset = XarmDataset("xarm_lift_medium_replay")
dataset = XarmDataset("xarm_push_medium")
dataset = XarmDataset("xarm_push_medium_replay")
dataset = AlohaDataset("aloha_sim_insertion_human")
dataset = AlohaDataset("aloha_sim_insertion_scripted")
dataset = AlohaDataset("aloha_sim_transfer_cube_human")

View File

@ -50,7 +50,12 @@ available_datasets = {
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
"xarm": ["xarm_lift_medium"],
"xarm": [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
],
}
available_policies = [

View File

@ -27,7 +27,7 @@ class AlohaDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.0",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,

View File

@ -29,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str = "pusht",
version: str | None = "v1.0",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,

View File

@ -6,9 +6,11 @@ import datasets
import einops
import torch
import tqdm
from datasets import load_dataset, load_from_disk
from datasets import Image, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
from lerobot.common.utils.utils import set_global_seed
@ -37,15 +39,32 @@ def unflatten_dict(d, sep="/"):
return outdict
def hf_transform_to_torch(items_dict):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
with channel first (c h w) of float32 type in range [0,1].
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(Path(root) / dataset_id / split)
else:
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
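As a usage note, the transform is applied lazily by the datasets library at access time, so indexing the returned dataset yields torch tensors rather than PIL images or Python lists. A small sketch outside the diff:
```python
from datasets import load_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch

hf_dataset = load_dataset("lerobot/pusht", revision="v1.1", split="train")
hf_dataset.set_transform(hf_transform_to_torch)
item = hf_dataset[0]  # images are float32 (c h w) in [0, 1], scalars become torch tensors
```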
@ -172,16 +191,6 @@ def load_previous_and_future_frames(
return item
def convert_images_to_channel_first_tensors(examples):
for key in examples:
if examples[key].ndim == 3: # we assume it's an image
# (h w c) -> (c h w)
h, w, c = examples[key].shape
assert c < h and c < w, f"expect a channel last image, but instead {examples[key].shape}"
examples[key] = [img.permute((2, 0, 1)) for img in examples[key]]
return examples
def get_stats_einops_patterns(hf_dataset):
"""These einops patterns will be used to aggregate batches and compute statistics.
@ -198,11 +207,19 @@ def get_stats_einops_patterns(hf_dataset):
stats_patterns = {}
for key, feats_type in hf_dataset.features.items():
if batch[key].ndim == 4 and isinstance(feats_type, datasets.features.image.Image):
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, Image):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# convert from (h w c) to (c h w) to fit pytorch convention, then apply reduce
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than or equal to 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than or equal to 0, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "

View File

@ -9,17 +9,25 @@ from lerobot.common.datasets.utils import load_previous_and_future_frames
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
https://huggingface.co/datasets/lerobot/xarm_push_medium
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
"""
# Copied from lerobot/__init__.py
available_datasets = ["xarm_lift_medium"]
available_datasets = [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
]
fps = 15
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str = "xarm_lift_medium",
version: str | None = "v1.0",
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,

View File

@ -208,11 +208,12 @@ def eval_policy(
max_rewards.extend(batch_max_reward.tolist())
all_successes.extend(batch_success.tolist())
# similar logic is implemented in dataset preprocessing
# similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
num_episodes = dones.shape[0]
total_frames = 0
idx_from = 0
id_from = 0
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
total_frames += num_frames
@ -227,14 +228,15 @@ def eval_policy(
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id][:num_frames]
ep_dicts.append(ep_dict)
idx_from += num_frames
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
# similar logic is implemented in dataset preprocessing
if return_episode_data:
@ -307,7 +309,10 @@ def eval_policy(
},
}
if return_episode_data:
info["episodes"] = hf_dataset
info["episodes"] = {
"hf_dataset": hf_dataset,
"episode_data_index": episode_data_index,
}
if max_episodes_rendered > 0:
info["videos"] = videos
return info
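On the consumer side, callers that request episode data now unpack both pieces. A hedged sketch, given the info dict returned by eval_policy with return_episode_data=True:
```python
def consume_eval_episodes(info: dict) -> None:
    # Both the rollout dataset and its episode boundaries are now available.
    hf_dataset = info["episodes"]["hf_dataset"]
    episode_data_index = info["episodes"]["episode_data_index"]
    first_ep_frames = episode_data_index["to"][0] - episode_data_index["from"][0]
    print(f"first rollout episode has {first_ep_frames} frames out of {len(hf_dataset)}")
```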

View File

@ -136,6 +136,7 @@ def add_episodes_inplace(
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
@ -151,6 +152,8 @@ def add_episodes_inplace(
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary with two keys ("from" and "to") mapping to dataset indices.
They indicate the start index and end index (exclusive) of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
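To make the new `episode_data_index` argument concrete, an illustrative example (values are made up) for two episodes of 3 and 5 frames:
```python
import torch

episode_data_index = {
    "from": torch.tensor([0, 3]),  # first frame index of each episode in the dataset
    "to": torch.tensor([3, 8]),    # one past the last frame index of each episode
}
```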
@ -174,14 +177,15 @@ def add_episodes_inplace(
# note: we don't shift "frame_index" since it represents the index of the frame in the episode it belongs to
example["episode_index"] += start_episode
example["index"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices)
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
seed=cfg.seed,
)
online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace(
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
online_pc_sampling=cfg.get("demo_schedule", 0.5),
)
for _ in range(cfg.policy.utd):

View File

@ -18,7 +18,6 @@ from lerobot.common.datasets.utils import (
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE
@ -102,22 +101,18 @@ def test_compute_stats_on_xarm():
data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=data_dir,
transform=transform,
)
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader(
@ -126,18 +121,18 @@ def test_compute_stats_on_xarm():
batch_size=len(dataset),
shuffle=False,
)
hf_dataset = next(iter(dataloader))
full_batch = next(iter(dataloader))
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
for k, pattern in stats_patterns.items():
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(
einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
)
expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
# test computed stats match expected stats
for k in stats_patterns:
@ -146,17 +141,15 @@ def test_compute_stats_on_xarm():
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# TODO(rcadene): check that the stats used for training are correct too
# # load stats that are expected to match the ones returned by computed_stats
# assert (dataset.data_dir / "stats.pth").exists()
# loaded_stats = torch.load(dataset.data_dir / "stats.pth")
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.stats
# # test loaded stats match expected stats
# for k in stats_patterns:
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
# test loaded stats match expected stats
for k in stats_patterns:
assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
def test_load_previous_and_future_frames_within_tolerance():