Improve dataset examples (#82)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
commit 0928afd37d
parent d5c4b0c344
@@ -200,13 +200,13 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
         "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
-    dataset = Dataset.from_dict(data_dict, features=features)
-    dataset = dataset.with_format("torch")
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
+    hf_dataset = hf_dataset.with_format("torch")
 
     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
 
 
 def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
@@ -311,13 +311,13 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
         "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
-    dataset = Dataset.from_dict(data_dict, features=features)
-    dataset = dataset.with_format("torch")
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
+    hf_dataset = hf_dataset.with_format("torch")
 
     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
 
 
 def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
@@ -460,13 +460,13 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
         "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
-    dataset = Dataset.from_dict(data_dict, features=features)
-    dataset = dataset.with_format("torch")
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
+    hf_dataset = hf_dataset.with_format("torch")
 
     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
 
 
 if __name__ == "__main__":
@@ -0,0 +1,59 @@
+"""
+This script demonstrates the visualization of various robotic datasets from the Hugging Face hub.
+It covers the steps of loading the datasets, filtering specific episodes, and converting the frame data to MP4 videos.
+Importantly, the dataset format is agnostic to any deep learning library and doesn't require using `lerobot` functions.
+It is compatible with pytorch, jax, numpy, etc.
+
+As an example, this script saves the frames of episode number 5 of the PushT dataset to an MP4 video stored here:
+`outputs/examples/1_load_hugging_face_dataset/episode_5.mp4`
+
+This script supports several Hugging Face datasets, including:
+1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
+2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
+3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
+4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
+5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
+6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
+
+To try a different Hugging Face dataset, you can replace this line:
+```python
+hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
+```
+with one of these:
+```python
+hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
+hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
+hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
+hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
+hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
+```
+"""
+
+from pathlib import Path
+
+import imageio
+from datasets import load_dataset
+
+# TODO(rcadene): list available datasets on lerobot page using `datasets`
+
+# download/load hugging face dataset in pyarrow format
+hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
+
+# display name of dataset and its features
+print(f"{hf_dataset=}")
+print(f"{hf_dataset.features=}")
+
+# display useful statistics about frames and episodes, which are sequences of frames from the same video
+print(f"number of frames: {len(hf_dataset)=}")
+print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
+print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
+
+# select the frames belonging to episode number 5
+hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
+
+# load all frames of episode 5 in RAM in PIL format
+frames = hf_dataset["observation.image"]
+
+# save episode frames to an mp4 video
+Path("outputs/examples/1_load_hugging_face_dataset").mkdir(parents=True, exist_ok=True)
+imageio.mimsave("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4", frames, fps=fps)
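The docstring above stresses that the raw Hugging Face dataset is framework-agnostic. As an illustration of that claim, here is a minimal sketch that consumes the same dataset as numpy arrays instead of torch tensors; the column names are taken from the example above, and the exact shapes are an assumption:

```python
from datasets import load_dataset

# Same dataset as above, materialized as numpy arrays instead of torch tensors.
hf_dataset = load_dataset("lerobot/pusht", split="train").with_format("numpy")

states = hf_dataset["observation.state"]  # assumption: fixed-size state vectors -> 2D numpy array
episode_ids = hf_dataset["episode_id"]    # 1D numpy array of episode indices

print(states.shape)                # (num_frames, state_dim)
print(int(episode_ids.max()) + 1)  # number of episodes, assuming ids start at 0
```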
@@ -1,20 +0,0 @@
-import os
-from pathlib import Path
-
-import lerobot
-from lerobot.common.datasets.pusht import PushtDataset
-from lerobot.scripts.visualize_dataset import render_dataset
-
-print(lerobot.available_datasets)
-# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
-
-# TODO(rcadene): remove DATA_DIR
-dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
-
-video_paths = render_dataset(
-    dataset,
-    out_dir="outputs/visualize_dataset/example",
-    max_num_episodes=1,
-)
-print(video_paths)
-# ['outputs/visualize_dataset/example/episode_0.mp4']
@@ -0,0 +1,98 @@
+"""
+This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face.
+It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
+
+Features included in this script:
+- Loading a dataset and accessing its properties.
+- Filtering data by episode number.
+- Converting tensor data for visualization.
+- Saving video files from dataset frames.
+- Using advanced dataset features like timestamp-based frame selection.
+- Demonstrating compatibility with PyTorch DataLoader for batch processing.
+
+The script ends with examples of how to batch process data using PyTorch's DataLoader.
+
+To try a different Hugging Face dataset, you can replace:
+```python
+dataset = PushtDataset()
+```
+with one of these:
+```python
+dataset = XarmDataset()
+dataset = AlohaDataset("aloha_sim_insertion_human")
+dataset = AlohaDataset("aloha_sim_insertion_scripted")
+dataset = AlohaDataset("aloha_sim_transfer_cube_human")
+dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
+```
+"""
+
+from pathlib import Path
+
+import imageio
+import torch
+
+from lerobot.common.datasets.pusht import PushtDataset
+
+# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
+# print("List of available datasets", lerobot.available_datasets)
+# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
+# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
+# # 'pusht', 'xarm_lift_medium']
+
+
+# You can easily load datasets from LeRobot
+dataset = PushtDataset()
+
+# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
+print(f"{dataset=}")
+print(f"{dataset.hf_dataset=}")
+
+# and provide additional utilities for robotics and compatibility with pytorch
+print(f"number of samples/frames: {dataset.num_samples=}")
+print(f"number of episodes: {dataset.num_episodes=}")
+print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
+print(f"frames per second used during data collection: {dataset.fps=}")
+print(f"keys to access images from cameras: {dataset.image_keys=}")
+
+# While the LeRobot dataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. It may be freely replaced or modified in place. Here we use filtering to keep only frames from episode 5.
+dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
+
+# LeRobot datasets actually subclass PyTorch datasets, so you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
+frames = [sample["observation.image"] for sample in dataset]
+
+# but frames are now channel first to follow the pytorch convention,
+# to view them, we convert to channel last
+frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
+
+# and finally save them to an mp4 video
+Path("outputs/examples/2_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
+imageio.mimsave("outputs/examples/2_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
+
+# For many machine learning applications we need to load histories of past observations, or trajectories of future actions. Our datasets can load previous and future frames for each key/modality,
+# using timestamp differences with the currently loaded frame. For instance:
+delta_timestamps = {
+    # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
+    "observation.image": [-1, -0.5, -0.20, 0],
+    # loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
+    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
+    # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
+    "action": [t / dataset.fps for t in range(64)],
+}
+dataset = PushtDataset(delta_timestamps=delta_timestamps)
+print(f"{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
+print(f"{dataset[0]['observation.state'].shape=}")  # (8,c)
+print(f"{dataset[0]['action'].shape=}")  # (64,c)
+
+# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
+# because they are just PyTorch datasets.
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    num_workers=4,
+    batch_size=32,
+    shuffle=True,
+)
+for batch in dataloader:
+    print(f"{batch['observation.image'].shape=}")  # (32,4,c,h,w)
+    print(f"{batch['observation.state'].shape=}")  # (32,8,c)
+    print(f"{batch['action'].shape=}")  # (32,64,c)
+    break
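The `delta_timestamps` values in the new example are plain seconds relative to the current frame. A small self-contained sketch of the arithmetic, assuming the 10 fps PushT recording rate stated in example 1 (no lerobot import needed):

```python
fps = 10  # PushT recording rate, per example 1

image_deltas = [-1, -0.5, -0.20, 0]
print([round(d * fps) for d in image_deltas])  # [-10, -5, -2, 0] frames relative to the current one

action_deltas = [t / fps for t in range(64)]
print(action_deltas[:4])  # [0.0, 0.1, 0.2, 0.3] seconds into the future
print(action_deltas[-1])  # 6.3 -> the last requested action is 6.3 s after the current frame
```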
@@ -40,31 +40,31 @@ class AlohaDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
         if self.root is not None:
-            self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
+            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
         else:
-            self.data_dict = load_dataset(
+            self.hf_dataset = load_dataset(
                 f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
             )
-        self.data_dict = self.data_dict.with_format("torch")
+        self.hf_dataset = self.hf_dataset.with_format("torch")
 
     @property
     def num_samples(self) -> int:
-        return len(self.data_dict)
+        return len(self.hf_dataset)
 
     @property
     def num_episodes(self) -> int:
-        return len(self.data_dict.unique("episode_id"))
+        return len(self.hf_dataset.unique("episode_id"))
 
     def __len__(self):
         return self.num_samples
 
     def __getitem__(self, idx):
-        item = self.data_dict[idx]
+        item = self.hf_dataset[idx]
 
         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
                 item,
-                self.data_dict,
+                self.hf_dataset,
                 self.delta_timestamps,
             )
 
@@ -23,7 +23,7 @@ class PushtDataset(torch.utils.data.Dataset):
 
     def __init__(
         self,
-        dataset_id: str,
+        dataset_id: str = "pusht",
         version: str | None = "v1.0",
         root: Path | None = None,
         split: str = "train",
@@ -38,31 +38,31 @@ class PushtDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
         if self.root is not None:
-            self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
+            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
         else:
-            self.data_dict = load_dataset(
+            self.hf_dataset = load_dataset(
                 f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
             )
-        self.data_dict = self.data_dict.with_format("torch")
+        self.hf_dataset = self.hf_dataset.with_format("torch")
 
     @property
     def num_samples(self) -> int:
-        return len(self.data_dict)
+        return len(self.hf_dataset)
 
     @property
     def num_episodes(self) -> int:
-        return len(self.data_dict.unique("episode_id"))
+        return len(self.hf_dataset.unique("episode_id"))
 
     def __len__(self):
         return self.num_samples
 
     def __getitem__(self, idx):
-        item = self.data_dict[idx]
+        item = self.hf_dataset[idx]
 
         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
                 item,
-                self.data_dict,
+                self.hf_dataset,
                 self.delta_timestamps,
             )
 
@@ -1,6 +1,7 @@
 from copy import deepcopy
 from math import ceil
 
+import datasets
 import einops
 import torch
 import tqdm
@@ -8,7 +9,7 @@ import tqdm
 
 def load_previous_and_future_frames(
     item: dict[str, torch.Tensor],
-    data_dict: dict[str, torch.Tensor],
+    hf_dataset: datasets.Dataset,
     delta_timestamps: dict[str, list[float]],
     tol: float = 0.04,
 ) -> dict[torch.Tensor]:
@@ -24,7 +25,7 @@ def load_previous_and_future_frames(
 
     Parameters:
     - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
-    - data_dict (dict): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
+    - hf_dataset (datasets.Dataset): The Hugging Face dataset containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
    - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
 
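For intuition on `tol`, a minimal sketch of how a query timestamp is matched against episode timestamps; the numbers mirror the unit tests at the bottom of this diff (frames spaced 0.1 s apart, deltas of 0.139 versus 0.141):

```python
# Frames are 0.1 s apart, as in the tests near the end of this diff.
ep_timestamps = [0.1, 0.2, 0.3, 0.4, 0.5]
current_ts = ep_timestamps[2]  # the frame returned by dataset[2]
tol = 0.04

for delta in (0.139, 0.141):
    query_ts = current_ts + delta
    dists = [abs(t - query_ts) for t in ep_timestamps]
    nearest = min(range(len(dists)), key=dists.__getitem__)
    print(delta, nearest, dists[nearest] < tol)
# 0.139 -> nearest index 3 (t=0.4), within tol  -> accepted
# 0.141 -> nearest index 3 (t=0.4), outside tol -> rejected (the real function asserts, see the tests)
```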
@@ -40,7 +41,7 @@ def load_previous_and_future_frames(
     ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
 
     # load timestamps
-    ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
+    ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
 
     # we make the assumption that the timestamps are sorted
     ep_first_ts = ep_timestamps[0]
@@ -70,7 +71,7 @@ def load_previous_and_future_frames(
         data_ids = ep_data_ids[argmin_]
 
         # load frames modality
-        item[key] = data_dict.select_columns(key)[data_ids][key]
+        item[key] = hf_dataset.select_columns(key)[data_ids][key]
         item[f"{key}_is_pad"] = is_pad
 
     return item
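The `"{key}_is_pad"` flags attached above mark query timestamps that fell outside the episode and were filled with the nearest available frame. A minimal sketch of how a consumer might mask them out; the tensor shapes are assumptions for illustration:

```python
import torch

# A padded "action" chunk like the ones produced above: True in "action_is_pad" marks
# steps whose query timestamp fell outside the episode.
item = {
    "action": torch.randn(64, 2),                       # (horizon, action_dim); action_dim=2 is an assumption
    "action_is_pad": torch.zeros(64, dtype=torch.bool),
}
item["action_is_pad"][-3:] = True                        # pretend the last 3 steps were padded

valid_actions = item["action"][~item["action_is_pad"]]   # drop padded steps, e.g. before computing a loss
print(valid_actions.shape)                               # torch.Size([61, 2])
```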
@@ -19,7 +19,7 @@ class XarmDataset(torch.utils.data.Dataset):
 
     def __init__(
         self,
-        dataset_id: str,
+        dataset_id: str = "xarm_lift_medium",
         version: str | None = "v1.0",
         root: Path | None = None,
         split: str = "train",
@@ -34,31 +34,31 @@ class XarmDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
         if self.root is not None:
-            self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
+            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
         else:
-            self.data_dict = load_dataset(
+            self.hf_dataset = load_dataset(
                 f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
             )
-        self.data_dict = self.data_dict.with_format("torch")
+        self.hf_dataset = self.hf_dataset.with_format("torch")
 
     @property
     def num_samples(self) -> int:
-        return len(self.data_dict)
+        return len(self.hf_dataset)
 
     @property
     def num_episodes(self) -> int:
-        return len(self.data_dict.unique("episode_id"))
+        return len(self.hf_dataset.unique("episode_id"))
 
     def __len__(self):
         return self.num_samples
 
     def __getitem__(self, idx):
-        item = self.data_dict[idx]
+        item = self.hf_dataset[idx]
 
         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
                 item,
-                self.data_dict,
+                self.hf_dataset,
                 self.delta_timestamps,
             )
 
@@ -241,7 +241,7 @@ def eval_policy(
 
     data_dict["index"] = torch.arange(0, total_frames, 1)
 
-    data_dict = Dataset.from_dict(data_dict).with_format("torch")
+    hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
 
     if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
@@ -292,7 +292,7 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
-        "episodes": data_dict,
+        "episodes": hf_dataset,
     }
     if max_episodes_rendered > 0:
         info["videos"] = videos
@@ -2,10 +2,11 @@ import logging
 from copy import deepcopy
 from pathlib import Path
 
+import datasets
 import hydra
 import torch
 from datasets import concatenate_datasets
-from datasets.utils.logging import disable_progress_bar
+from datasets.utils import disable_progress_bars, enable_progress_bars
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -130,15 +131,40 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
     return -(n_off * pc_on) / (n_on * (pc_on - 1))
 
 
-def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
-    first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item()
-    first_index = data_dict.select_columns("index")[0]["index"].item()
+def add_episodes_inplace(
+    online_dataset: torch.utils.data.Dataset,
+    concat_dataset: torch.utils.data.ConcatDataset,
+    sampler: torch.utils.data.WeightedRandomSampler,
+    hf_dataset: datasets.Dataset,
+    pc_online_samples: float,
+):
+    """
+    Modifies the online_dataset, concat_dataset, and sampler in place by integrating
+    new episodes from hf_dataset into the online_dataset, updating the concatenated
+    dataset's structure and adjusting the sampling strategy based on the specified
+    percentage of online samples.
+
+    Parameters:
+    - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
+    - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
+      offline and online datasets, used for sampling purposes.
+    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
+      reflect changes in the dataset sizes and specified sampling weights.
+    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
+    - pc_online_samples (float): The target percentage of samples that should come from
+      the online dataset during sampling operations.
+
+    Raises:
+    - AssertionError: If the first episode_id or index in hf_dataset is not 0
+    """
+    first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
+    first_index = hf_dataset.select_columns("index")[0]["index"].item()
     assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
     assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
 
     if len(online_dataset) == 0:
         # initialize online dataset
-        online_dataset.data_dict = data_dict
+        online_dataset.hf_dataset = hf_dataset
     else:
         # find episode index and data frame indices according to previous episode in online_dataset
         start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
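For reference, a small sketch checking the weight formula in `calculate_online_sample_weight` shown at the top of this hunk; it assumes offline samples keep a weight of 1, which is what the formula implies, and the sample counts below are made up for illustration:

```python
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
    return -(n_off * pc_on) / (n_on * (pc_on - 1))

# Made-up sizes: 100 offline frames, 25 online frames, target 50% online draws.
n_off, n_on, pc_on = 100, 25, 0.5
w = calculate_online_sample_weight(n_off, n_on, pc_on)

# With every offline sample weighted 1 and every online sample weighted w,
# the expected fraction of online draws matches pc_on.
online_fraction = (n_on * w) / (n_on * w + n_off * 1.0)
print(w, online_fraction)  # 4.0 0.5
```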
@@ -152,11 +178,12 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
             example["episode_data_index_to"] += start_index
             return example
 
-        disable_progress_bar()  # map has a tqdm progress bar
-        data_dict = data_dict.map(shift_indices)
+        disable_progress_bars()  # map has a tqdm progress bar
+        hf_dataset = hf_dataset.map(shift_indices)
+        enable_progress_bars()
 
         # extend online dataset
-        online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict])
+        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
 
     # update the concatenated dataset length used during sampling
     concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
@@ -274,7 +301,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
-    online_dataset.data_dict = {}
+    online_dataset.hf_dataset = {}
 
     # create dataloader for online training
     concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
@@ -308,7 +335,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
             online_pc_sampling = cfg.get("demo_schedule", 0.5)
             add_episodes_inplace(
-                eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
+                online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
            )
 
             for _ in range(cfg.policy.utd):
@@ -1,63 +0,0 @@
-"""
-This script is designed to facilitate the creation of a subset of an existing dataset by selecting a specific number of frames from the original dataset.
-This subset can then be used for running quick unit tests.
-The script takes an input directory containing the original dataset and an output directory where the subset of the dataset will be saved.
-Additionally, the number of frames to include in the subset can be specified.
-The script ensures that the subset is a representative sample of the original dataset by copying the specified number of frames and retaining the structure and format of the data.
-
-Usage:
-Run the script with the following command, specifying the path to the input data directory,
-the path to the output data directory, and optionally the number of frames to include in the subset dataset:
-
-`python tests/scripts/mock_dataset.py --in-data-dir path/to/input_data --out-data-dir path/to/output_data`
-
-Example:
-`python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
-"""
-
-import argparse
-import shutil
-
-from pathlib import Path
-
-import torch
-
-
-def mock_dataset(in_data_dir, out_data_dir, num_frames):
-    in_data_dir = Path(in_data_dir)
-    out_data_dir = Path(out_data_dir)
-    out_data_dir.mkdir(exist_ok=True, parents=True)
-
-    # copy the first `n` frames for each data key so that we have real data
-    in_data_dict = torch.load(in_data_dir / "data_dict.pth")
-    out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
-    torch.save(out_data_dict, out_data_dir / "data_dict.pth")
-
-    # recreate data_ids_per_episode that corresponds to the subset
-    episodes = in_data_dict["episode"][:num_frames].tolist()
-    data_ids_per_episode = {}
-    for idx, ep_id in enumerate(episodes):
-        if ep_id not in data_ids_per_episode:
-            data_ids_per_episode[ep_id] = []
-        data_ids_per_episode[ep_id].append(idx)
-    for ep_id in data_ids_per_episode:
-        data_ids_per_episode[ep_id] = torch.tensor(data_ids_per_episode[ep_id])
-    torch.save(data_ids_per_episode, out_data_dir / "data_ids_per_episode.pth")
-
-    # copy the full statistics of dataset since it's small
-    in_stats_path = in_data_dir / "stats.pth"
-    out_stats_path = out_data_dir / "stats.pth"
-    shutil.copy(in_stats_path, out_stats_path)
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description="Create a dataset with a subset of frames for quick testing.")
-
-    parser.add_argument("--in-data-dir", type=str, help="Path to input data")
-    parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
-    parser.add_argument("--num-frames", type=int, default=50, help="Number of frames to copy over")
-
-    args = parser.parse_args()
-
-    mock_dataset(args.in_data_dir, args.out_data_dir, args.num_frames)
@@ -50,7 +50,7 @@ def test_factory(env_name, dataset_id, policy_name):
         keys_ndim_required.append(
             (key, 3, True),
         )
-        assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
+        assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
 
     # test number of dimensions
     for key, ndim, required in keys_ndim_required:
@@ -121,16 +121,16 @@ def test_compute_stats():
         batch_size=len(dataset),
         shuffle=False,
     )
-    data_dict = next(iter(dataloader))
+    hf_dataset = next(iter(dataloader))
 
     # compute stats based on all frames from the dataset without any batching
     expected_stats = {}
     for k, pattern in stats_patterns.items():
         expected_stats[k] = {}
-        expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
-        expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
-        expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
-        expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
+        expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
+        expected_stats[k]["std"] = torch.sqrt(einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
+        expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
 
     # test computed stats match expected stats
     for k in stats_patterns:
@@ -153,47 +153,47 @@
 
 
 def test_load_previous_and_future_frames_within_tolerance():
-    data_dict = Dataset.from_dict({
+    hf_dataset = Dataset.from_dict({
         "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
         "index": [0, 1, 2, 3, 4],
         "episode_data_index_from": [0, 0, 0, 0, 0],
         "episode_data_index_to": [5, 5, 5, 5, 5],
     })
-    data_dict = data_dict.with_format("torch")
-    item = data_dict[2]
+    hf_dataset = hf_dataset.with_format("torch")
+    item = hf_dataset[2]
     delta_timestamps = {"index": [-0.2, 0, 0.139]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, data_dict, delta_timestamps, tol)
+    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
     assert not is_pad.any(), "Unexpected padding detected"
 
def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
-    data_dict = Dataset.from_dict({
+    hf_dataset = Dataset.from_dict({
         "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
         "index": [0, 1, 2, 3, 4],
         "episode_data_index_from": [0, 0, 0, 0, 0],
         "episode_data_index_to": [5, 5, 5, 5, 5],
     })
-    data_dict = data_dict.with_format("torch")
-    item = data_dict[2]
+    hf_dataset = hf_dataset.with_format("torch")
+    item = hf_dataset[2]
     delta_timestamps = {"index": [-0.2, 0, 0.141]}
     tol = 0.04
     with pytest.raises(AssertionError):
-        load_previous_and_future_frames(item, data_dict, delta_timestamps, tol)
+        load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
 
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
-    data_dict = Dataset.from_dict({
+    hf_dataset = Dataset.from_dict({
         "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
         "index": [0, 1, 2, 3, 4],
         "episode_data_index_from": [0, 0, 0, 0, 0],
         "episode_data_index_to": [5, 5, 5, 5, 5],
     })
-    data_dict = data_dict.with_format("torch")
-    item = data_dict[2]
+    hf_dataset = hf_dataset.with_format("torch")
+    item = hf_dataset[2]
     delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, data_dict, delta_timestamps, tol)
+    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
     assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
@@ -1,4 +1,5 @@
 from pathlib import Path
+import subprocess
 
 
 def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
@@ -8,23 +9,29 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
     return text
 
 
+def _run_script(path):
+    subprocess.run(['python', path], check=True)
+
+
 def test_example_1():
-    path = "examples/1_visualize_dataset.py"
-
-    with open(path, "r") as file:
-        file_contents = file.read()
-    exec(file_contents)
-
-    assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
+    path = "examples/1_load_hugging_face_dataset.py"
+    _run_script(path)
+    assert Path("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4").exists()
 
 
-def test_examples_3_and_2():
+def test_example_2():
+    path = "examples/2_load_lerobot_dataset.py"
+    _run_script(path)
+    assert Path("outputs/examples/2_load_lerobot_dataset/episode_5.mp4").exists()
+
+
+def test_examples_4_and_3():
     """
     Train a model with example 3, check the outputs.
     Evaluate the trained model with example 2, check the outputs.
     """
 
-    path = "examples/3_train_policy.py"
+    path = "examples/4_train_policy.py"
 
     with open(path, "r") as file:
         file_contents = file.read()
@@ -46,7 +53,7 @@ def test_examples_3_and_2():
     for file_name in ["model.pt", "stats.pth", "config.yaml"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
 
-    path = "examples/2_evaluate_pretrained_policy.py"
+    path = "examples/3_evaluate_pretrained_policy.py"
 
     with open(path, "r") as file:
         file_contents = file.read()