Use v1.1, hf_transform_to_torch, Add 3 xarm datasets
parent 714a776277
commit 35a573c98e
@@ -208,7 +208,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS

 You will need to set the corresponding version as a default argument in your dataset class:
 ```python
-version: str | None = "v1.0",
+version: str | None = "v1.1",
 ```
 See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)

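
For context, a minimal sketch (not part of this commit) of how a dataset class can pass such a version default down to the Hub loader. The class name and attributes are illustrative; only the `v1.1` default and the `revision=` plumbing mirror the diff:

```python
from datasets import load_dataset


class MyDataset:
    """Illustrative only: shows where the bumped default would live."""

    def __init__(self, dataset_id: str = "pusht", version: str | None = "v1.1"):
        self.dataset_id = dataset_id
        self.version = version
        # the version is used as the Hub revision (git tag/branch) of the dataset repo
        self.hf_dataset = load_dataset(f"lerobot/{dataset_id}", revision=version, split="train")
```
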
@@ -19,7 +19,7 @@ from huggingface_hub import HfApi
 from PIL import Image as PILImage
 from safetensors.torch import save_file

-from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
+from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch


 def download_and_upload(root, revision, dataset_id):
@@ -127,7 +127,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat

     # copy in tests folder, the first episode and the meta_data directory
     num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
-    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
+    hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
+        f"tests/data/{dataset_id}/train"
+    )
+    if Path(f"tests/data/{dataset_id}/meta_data").exists():
+        shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
     shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")

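
A rough sketch of the round trip this enables for the test fixtures, using only standard `datasets` APIs (the toy data, path, and episode size below are illustrative, not taken from the commit):

```python
from datasets import Dataset, load_from_disk

# hypothetical toy data standing in for the real hf_dataset
hf_dataset = Dataset.from_dict({"index": list(range(100)), "episode_index": [0] * 50 + [1] * 50})

# keep only the first episode, store it with the torch format, then reload it as the tests would
num_items_first_ep = 50
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk("tests/data/demo/train")
reloaded = load_from_disk("tests/data/demo/train")
print(len(reloaded))  # 50
```
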
@@ -262,8 +266,7 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
     }
     features = Features(features)
     hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset = hf_dataset.with_format("torch")
-    hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+    hf_dataset.set_transform(hf_transform_to_torch)

     info = {
         "fps": fps,
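
The same replacement of `with_format("torch")` plus a custom image transform by a single `set_transform(hf_transform_to_torch)` appears in all three `download_and_upload_*` functions in this commit. The practical difference, sketched here with plain `datasets` APIs and a made-up transform, is that a transform is applied lazily at access time instead of baking a fixed output format into the dataset:

```python
import torch
from datasets import Dataset

ds = Dataset.from_dict({"action": [[0.0, 1.0], [2.0, 3.0]]})


def to_torch(batch):
    # called on-the-fly for every item or batch access
    return {k: [torch.tensor(v) for v in batch[k]] for k in batch}


ds.set_transform(to_torch)
print(type(ds[0]["action"]))  # <class 'torch.Tensor'>
```
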
@@ -274,13 +277,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):

 def download_and_upload_xarm(root, revision, dataset_id, fps=15):
     root = Path(root)
-    raw_dir = root / f"{dataset_id}_raw"
+    raw_dir = root / "xarm_datasets_raw"
     if not raw_dir.exists():
         import zipfile

         import gdown

         raw_dir.mkdir(parents=True, exist_ok=True)
+        # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
         url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
         zip_path = raw_dir / "data.zip"
         gdown.download(url, str(zip_path), quiet=False)
@@ -361,8 +365,7 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
     }
     features = Features(features)
     hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset = hf_dataset.with_format("torch")
-    hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+    hf_dataset.set_transform(hf_transform_to_torch)

     info = {
         "fps": fps,
@@ -468,8 +471,6 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
                 # "next.reward": reward,
                 "next.done": done,
                 # "next.success": success,
-                "episode_data_index_from": torch.tensor([id_from] * num_frames),
-                "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
             }
         )

@@ -499,8 +500,7 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
     }
     features = Features(features)
     hf_dataset = Dataset.from_dict(data_dict, features=features)
-    hf_dataset = hf_dataset.with_format("torch")
-    hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+    hf_dataset.set_transform(hf_transform_to_torch)

     info = {
         "fps": fps,
@@ -515,11 +515,14 @@ if __name__ == "__main__":

     dataset_ids = [
         "pusht",
-        # "xarm_lift_medium",
-        # "aloha_sim_insertion_human",
-        # "aloha_sim_insertion_scripted",
-        # "aloha_sim_transfer_cube_human",
-        # "aloha_sim_transfer_cube_scripted",
+        "xarm_lift_medium",
+        "xarm_lift_medium_replay",
+        "xarm_push_medium",
+        "xarm_push_medium_replay",
+        "aloha_sim_insertion_human",
+        "aloha_sim_insertion_scripted",
+        "aloha_sim_transfer_cube_human",
+        "aloha_sim_transfer_cube_scripted",
     ]
     for dataset_id in dataset_ids:
         download_and_upload(root, revision, dataset_id)
@@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
 This script supports several Hugging Face datasets, among which:
 1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
 2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
-3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
-4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
-5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
-6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
+3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
+4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
+5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
+6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
+7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
+8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
+9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)

 To try a different Hugging Face dataset, you can replace this line:
 ```python
@@ -22,6 +25,9 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
 by one of these:
 ```python
 hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
+hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
+hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
+hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
 hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
 hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
 hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
@@ -18,7 +18,10 @@ dataset = PushtDataset()
 ```
 by one of these:
 ```python
-dataset = XarmDataset()
+dataset = XarmDataset("xarm_lift_medium")
+dataset = XarmDataset("xarm_lift_medium_replay")
+dataset = XarmDataset("xarm_push_medium")
+dataset = XarmDataset("xarm_push_medium_replay")
 dataset = AlohaDataset("aloha_sim_insertion_human")
 dataset = AlohaDataset("aloha_sim_insertion_scripted")
 dataset = AlohaDataset("aloha_sim_transfer_cube_human")
@@ -50,7 +50,12 @@ available_datasets = {
         "aloha_sim_transfer_cube_scripted",
     ],
     "pusht": ["pusht"],
-    "xarm": ["xarm_lift_medium"],
+    "xarm": [
+        "xarm_lift_medium",
+        "xarm_lift_medium_replay",
+        "xarm_push_medium",
+        "xarm_push_medium_replay",
+    ],
 }

 available_policies = [
@@ -27,7 +27,7 @@ class AlohaDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         dataset_id: str,
-        version: str | None = "v1.0",
+        version: str | None = "v1.1",
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,
@@ -29,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         dataset_id: str = "pusht",
-        version: str | None = "v1.0",
+        version: str | None = "v1.1",
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,
@@ -6,9 +6,11 @@ import datasets
 import einops
 import torch
 import tqdm
-from datasets import load_dataset, load_from_disk
+from datasets import Image, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
+from PIL import Image as PILImage
 from safetensors.torch import load_file
+from torchvision import transforms

 from lerobot.common.utils.utils import set_global_seed

@@ -37,15 +39,32 @@ def unflatten_dict(d, sep="/"):
     return outdict


+def hf_transform_to_torch(items_dict):
+    """Get a transform function that converts items from a Hugging Face dataset (pyarrow)
+    to torch tensors. Importantly, images are converted from PIL, which corresponds to
+    a channel last representation (h w c) of uint8 type, to a torch image representation
+    with channel first (c h w) of float32 type in range [0,1].
+    """
+    for key in items_dict:
+        first_item = items_dict[key][0]
+        if isinstance(first_item, PILImage.Image):
+            to_tensor = transforms.ToTensor()
+            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        else:
+            items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
+    return items_dict


 def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
         hf_dataset = load_from_disk(Path(root) / dataset_id / split)
     else:
+        # TODO(rcadene): remove dataset_id everywhere and use repo_id instead
         repo_id = f"lerobot/{dataset_id}"
         hf_dataset = load_dataset(repo_id, revision=version, split=split)
     hf_dataset = hf_dataset.with_format("torch")
-    hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+    hf_dataset.set_transform(hf_transform_to_torch)
     return hf_dataset

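
A quick way to see what the new transform does, as a sketch: it assumes the package at this commit is importable, and the toy dataset below (keys, image size, values) is illustrative rather than taken from the repo:

```python
import numpy as np
from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage

from lerobot.common.datasets.utils import hf_transform_to_torch

img = PILImage.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
ds = Dataset.from_dict(
    {"observation.image": [img], "action": [[0.1, 0.2]]},
    features=Features({"observation.image": Image(), "action": [Value("float32")]}),
)
ds.set_transform(hf_transform_to_torch)

item = ds[0]
print(item["observation.image"].shape, item["observation.image"].dtype)
# torch.Size([3, 64, 64]) torch.float32  -> channel first, scaled to [0, 1]
```
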
@@ -172,16 +191,6 @@ def load_previous_and_future_frames(
     return item


-def convert_images_to_channel_first_tensors(examples):
-    for key in examples:
-        if examples[key].ndim == 3:  # we assume it's an image
-            # (h w c) -> (c h w)
-            h, w, c = examples[key].shape
-            assert c < h and c < w, f"expect a channel last image, but instead {examples[key].shape}"
-            examples[key] = [img.permute((2, 0, 1)) for img in examples[key]]
-    return examples
-
-
 def get_stats_einops_patterns(hf_dataset):
     """These einops patterns will be used to aggregate batches and compute statistics.

@@ -198,11 +207,19 @@ def get_stats_einops_patterns(hf_dataset):

     stats_patterns = {}
     for key, feats_type in hf_dataset.features.items():
-        if batch[key].ndim == 4 and isinstance(feats_type, datasets.features.image.Image):
+        # sanity check that tensors are not float64
+        assert batch[key].dtype != torch.float64
+
+        if isinstance(feats_type, Image):
             # sanity check that images are channel first
             _, c, h, w = batch[key].shape
             assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
-            # convert from (h w c) to (c h w) to fit pytorch convention, then apply reduce
+
+            # sanity check that images are float32 in range [0,1]
+            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
+            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
+            assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"
+
             stats_patterns[key] = "b c h w -> c 1 1"
         elif batch[key].ndim == 2:
             stats_patterns[key] = "b c -> c "
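
For reference, a small sketch of what the einops patterns above compute (the key names and tensor shapes are made up):

```python
import einops
import torch

batch = {"observation.image": torch.rand(8, 3, 96, 96)}  # (b c h w), float32 in [0, 1]

# per-channel statistic, matching the "b c h w -> c 1 1" pattern used for image keys
mean = einops.reduce(batch["observation.image"], "b c h w -> c 1 1", "mean")
print(mean.shape)  # torch.Size([3, 1, 1])

# per-dimension statistic, matching the "b c -> c " pattern used for state/action vectors
state = torch.rand(8, 4)
print(einops.reduce(state, "b c -> c", "mean").shape)  # torch.Size([4])
```
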
@@ -9,17 +9,25 @@ from lerobot.common.datasets.utils import load_previous_and_future_frames
 class XarmDataset(torch.utils.data.Dataset):
     """
     https://huggingface.co/datasets/lerobot/xarm_lift_medium
+    https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
+    https://huggingface.co/datasets/lerobot/xarm_push_medium
+    https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
     """

     # Copied from lerobot/__init__.py
-    available_datasets = ["xarm_lift_medium"]
+    available_datasets = [
+        "xarm_lift_medium",
+        "xarm_lift_medium_replay",
+        "xarm_push_medium",
+        "xarm_push_medium_replay",
+    ]
     fps = 15
     image_keys = ["observation.image"]

     def __init__(
         self,
-        dataset_id: str = "xarm_lift_medium",
-        version: str | None = "v1.0",
+        dataset_id: str,
+        version: str | None = "v1.1",
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,
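
Since `dataset_id` loses its default here, callers now have to name one of the four xarm datasets explicitly. A usage sketch, assuming the class lives in `lerobot/common/datasets/xarm.py` (mirroring the `pusht.py` path referenced earlier; the exact module path is an assumption):

```python
from lerobot.common.datasets.xarm import XarmDataset

# any of: xarm_lift_medium, xarm_lift_medium_replay, xarm_push_medium, xarm_push_medium_replay
dataset = XarmDataset(dataset_id="xarm_lift_medium_replay")
print(dataset.fps)  # 15
```
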
@@ -208,11 +208,12 @@ def eval_policy(
         max_rewards.extend(batch_max_reward.tolist())
         all_successes.extend(batch_success.tolist())

-    # similar logic is implemented in dataset preprocessing
+    # similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
     ep_dicts = []
+    episode_data_index = {"from": [], "to": []}
     num_episodes = dones.shape[0]
     total_frames = 0
-    idx_from = 0
+    id_from = 0
     for ep_id in range(num_episodes):
         num_frames = done_indices[ep_id].item() + 1
         total_frames += num_frames
@@ -227,14 +228,15 @@ def eval_policy(
             "timestamp": torch.arange(0, num_frames, 1) / fps,
             "next.done": dones[ep_id, :num_frames],
             "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
-            "episode_data_index_from": torch.tensor([idx_from] * num_frames),
-            "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
         }
         for key in observations:
             ep_dict[key] = observations[key][ep_id][:num_frames]
         ep_dicts.append(ep_dict)

-        idx_from += num_frames
+        episode_data_index["from"].append(id_from)
+        episode_data_index["to"].append(id_from + num_frames)
+
+        id_from += num_frames

     # similar logic is implemented in dataset preprocessing
     if return_episode_data:
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if return_episode_data:
|
if return_episode_data:
|
||||||
info["episodes"] = hf_dataset
|
info["episodes"] = {
|
||||||
|
"hf_dataset": hf_dataset,
|
||||||
|
"episode_data_index": episode_data_index,
|
||||||
|
}
|
||||||
if max_episodes_rendered > 0:
|
if max_episodes_rendered > 0:
|
||||||
info["videos"] = videos
|
info["videos"] = videos
|
||||||
return info
|
return info
|
||||||
|
|
|
@@ -136,6 +136,7 @@ def add_episodes_inplace(
     concat_dataset: torch.utils.data.ConcatDataset,
     sampler: torch.utils.data.WeightedRandomSampler,
     hf_dataset: datasets.Dataset,
+    episode_data_index: dict[str, torch.Tensor],
     pc_online_samples: float,
 ):
     """
@@ -151,6 +152,8 @@ def add_episodes_inplace(
     - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
         reflect changes in the dataset sizes and specified sampling weights.
     - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
+    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated with dataset indices.
+        They indicate the start index and end index of each episode in the dataset.
     - pc_online_samples (float): The target percentage of samples that should come from
         the online dataset during sampling operations.

@@ -174,14 +177,15 @@ def add_episodes_inplace(
         # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
         example["episode_index"] += start_episode
         example["index"] += start_index
-        example["episode_data_index_from"] += start_index
-        example["episode_data_index_to"] += start_index
        return example

    disable_progress_bars()  # map has a tqdm progress bar
    hf_dataset = hf_dataset.map(shift_indices)
    enable_progress_bars()

+   episode_data_index["from"] += start_index
+   episode_data_index["to"] += start_index
+
    # extend online dataset
    online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
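
A schematic of the shift performed above, assuming `episode_data_index` holds torch tensors as typed in the new signature (all numbers are made up):

```python
import torch

# boundaries of the freshly collected episodes, indexed from 0
episode_data_index = {"from": torch.tensor([0, 30]), "to": torch.tensor([30, 55])}

start_index = 90  # frames already present in the online dataset

# shift so the new episodes index into the concatenated dataset
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
print(episode_data_index)  # {'from': tensor([ 90, 120]), 'to': tensor([120, 145])}
```
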
@@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 seed=cfg.seed,
             )

-            online_pc_sampling = cfg.get("demo_schedule", 0.5)
             add_episodes_inplace(
-                online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
+                online_dataset,
+                concat_dataset,
+                sampler,
+                hf_dataset=eval_info["episodes"]["hf_dataset"],
+                episode_data_index=eval_info["episodes"]["episode_data_index"],
+                online_pc_sampling=cfg.get("demo_schedule", 0.5),
             )

             for _ in range(cfg.policy.utd):
@@ -18,7 +18,6 @@ from lerobot.common.datasets.utils import (
     load_previous_and_future_frames,
     unflatten_dict,
 )
-from lerobot.common.transforms import Prod
 from lerobot.common.utils.utils import init_hydra_config

 from .utils import DEFAULT_CONFIG_PATH, DEVICE
@@ -102,22 +101,18 @@ def test_compute_stats_on_xarm():

     data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

-    # get transform to convert images from uint8 [0,255] to float32 [0,1]
-    transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
-
     dataset = XarmDataset(
         dataset_id="xarm_lift_medium",
         root=data_dir,
-        transform=transform,
     )

     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches.
-    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
+    computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))

     # get einops patterns to aggregate batches and compute statistics
-    stats_patterns = get_stats_einops_patterns(dataset)
+    stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)

     # get all frames from the dataset in the same dtype and range as during compute_stats
     dataloader = torch.utils.data.DataLoader(
@@ -126,18 +121,18 @@ def test_compute_stats_on_xarm():
         batch_size=len(dataset),
         shuffle=False,
     )
-    hf_dataset = next(iter(dataloader))
+    full_batch = next(iter(dataloader))

     # compute stats based on all frames from the dataset without any batching
     expected_stats = {}
     for k, pattern in stats_patterns.items():
         expected_stats[k] = {}
-        expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
+        expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
         expected_stats[k]["std"] = torch.sqrt(
-            einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
+            einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
         )
-        expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
-        expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
+        expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")

     # test computed stats match expected stats
     for k in stats_patterns:
|
||||||
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
|
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
|
||||||
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
|
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
|
||||||
|
|
||||||
# TODO(rcadene): check that the stats used for training are correct too
|
# load stats used during training which are expected to match the ones returned by computed_stats
|
||||||
# # load stats that are expected to match the ones returned by computed_stats
|
loaded_stats = dataset.stats
|
||||||
# assert (dataset.data_dir / "stats.pth").exists()
|
|
||||||
# loaded_stats = torch.load(dataset.data_dir / "stats.pth")
|
|
||||||
|
|
||||||
# # test loaded stats match expected stats
|
# test loaded stats match expected stats
|
||||||
# for k in stats_patterns:
|
for k in stats_patterns:
|
||||||
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
|
assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
|
||||||
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
|
assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
|
||||||
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
|
assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
|
||||||
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
||||||
|
|
||||||
|
|
||||||
def test_load_previous_and_future_frames_within_tolerance():
|
def test_load_previous_and_future_frames_within_tolerance():
|
||||||
|
|