diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py
index a69bc573..671ca9a1 100644
--- a/lerobot/common/datasets/compute_stats.py
+++ b/lerobot/common/datasets/compute_stats.py
@@ -43,6 +43,10 @@ def get_stats_einops_patterns(dataset, num_workers=0):
         # sanity check that tensors are not float64
         assert batch[key].dtype != torch.float64
 
+        # NOTE: skip language_instruction embedding in stats computation
+        if key == "language_instruction":
+            continue
+
         if isinstance(feats_type, (VideoFrame, Image)):
             # sanity check that images are channel first
             _, c, h, w = batch[key].shape
diff --git a/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py b/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py
index 854d31a6..0274062f 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py
@@ -8,21 +8,22 @@ Example:
     --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
     --repo-id youliangtan/sampled_bridge_data_v2 \
     --raw-format oxe_rlds \
-    --episodes 3 4 5 8 9
+    --episodes 3 4 5 8 9 \
+    --fps 5
+
+Exact dataset fps is specified in:
+    https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
 """
 
-import gc
 import shutil
 from pathlib import Path
 
-import h5py
 import numpy as np
+import tensorflow_datasets as tfds
 import torch
 import tqdm
 from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
-import tensorflow_datasets as tfds
-import cv2
 
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
@@ -42,13 +43,7 @@ def tf_to_torch(data):
     return torch.from_numpy(data.numpy())
 
 
-def load_from_raw(
-    raw_dir: Path,
-    videos_dir: Path,
-    fps: int,
-    video: bool,
-    episodes: list[int] | None = None
-):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     """
     Args:
         raw_dir (Path): _description_
@@ -58,13 +53,19 @@ def load_from_raw(
         episodes (list[int] | None, optional): _description_. Defaults to None.
     """
     ds_builder = tfds.builder_from_directory(str(raw_dir))
-    dataset = ds_builder.as_dataset(split='all')
+    dataset = ds_builder.as_dataset(split="all")
     dataset_info = ds_builder.info
     print("dataset_info: ", dataset_info)
 
-    image_keys = get_cameras_keys(
-        dataset_info.features["steps"]["observation"].keys())
-    print("image_keys: ", image_keys)
+    image_keys = get_cameras_keys(dataset_info.features["steps"]["observation"].keys())
+
+    # check if there's a 'tfds.features.Text' in step, only take 1 lang instruction
+    lang_key = [
+        key for key, value in dataset_info.features["steps"].items() if isinstance(value, tfds.features.Text)
+    ]
+    lang_key = None if len(lang_key) == 0 else lang_key[0]
+    print(" - image_keys: ", image_keys)
+    print(" - lang_key: ", lang_key)
 
     ds_length = len(dataset)
     dataset = dataset.take(ds_length)
@@ -88,40 +89,41 @@ def load_from_raw(
                 break
             if ep_idx == episodes[0]:
                 # process this episode
-                print(" selecting episode: ", ep_idx)
+                print(" selecting episode idx: ", ep_idx)
                 episodes.pop(0)
             else:
                 continue  # skip
 
-        steps = episode['steps']
-        eps_len = len(steps)
-        num_frames = eps_len  # TODO: check if this is correct
+        steps = episode["steps"]
+        num_frames = len(steps)
 
         # last step of demonstration is considered done
         done = torch.zeros(num_frames, dtype=torch.bool)
         done[-1] = True
 
         states = []
-        actions = []
+        actions = []  # TODO(YL): some actions can be a featuredict
         rewards = torch.zeros(num_frames, dtype=torch.float32)
 
         ep_dict = {}
+        langs = []
 
         image_array_dict = {key: [] for key in image_keys}
 
         ###########################################################
         # loop through all steps in the episode
         for j, step in enumerate(steps):
-            states.append(tf_to_torch(step['observation']['state']))
-            actions.append(tf_to_torch(step['action']))
-            rewards[j] = torch.tensor(step['reward'].numpy(), dtype=torch.float32)
+            states.append(tf_to_torch(step["observation"]["state"]))
+            actions.append(tf_to_torch(step["action"]))
+            rewards[j] = torch.tensor(step["reward"].numpy(), dtype=torch.float32)
 
-            # TODO: language_text, is_terminal, is_last etc.
+            if lang_key is not None:
+                langs.append(str(step[lang_key]))
 
             for im_key in image_keys:
-                if im_key not in step['observation']:
+                if im_key not in step["observation"]:
                     continue
-                img = step['observation'][im_key]
+                img = step["observation"][im_key]
                 img = np.array(img)
                 image_array_dict[im_key].append(img)
 
@@ -152,6 +154,9 @@ def load_from_raw(
         else:
             ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
 
+        if lang_key is not None:
+            ep_dict["language_instruction"] = langs
+
         ep_dict["observation.state"] = torch.stack(states)  # TODO better way
         ep_dict["action"] = torch.stack(actions)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -180,22 +185,21 @@ def to_hf_dataset(data_dict, video) -> Dataset:
             features[key] = Image()
 
     features["observation.state"] = Sequence(
-        length=data_dict["observation.state"].shape[1], feature=Value(
-            dtype="float32", id=None)
+        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
     )
     if "observation.velocity" in data_dict:
         features["observation.velocity"] = Sequence(
-            length=data_dict["observation.velocity"].shape[1], feature=Value(
-                dtype="float32", id=None)
+            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
         )
     if "observation.effort" in data_dict:
         features["observation.effort"] = Sequence(
-            length=data_dict["observation.effort"].shape[1], feature=Value(
-                dtype="float32", id=None)
+            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
         )
+    if "language_instruction" in data_dict:
+        features["language_instruction"] = Value(dtype="string", id=None)
+
     features["action"] = Sequence(
-        length=data_dict["action"].shape[1], feature=Value(
-            dtype="float32", id=None)
+        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
     features["episode_index"] = Value(dtype="int64", id=None)
     features["frame_index"] = Value(dtype="int64", id=None)
@@ -231,11 +235,17 @@ def from_raw_to_lerobot_format(
 
 
 if __name__ == "__main__":
-    # TODO (YL) remove this
+    # NOTE (YL): This mainly serves as a unit test
+    # austin_buds_dataset_converted_externally_to_rlds is a smaller dataset in
+    # open x embodiment datasets.
     raw_dir = Path("/hdd/tensorflow_datasets/austin_buds_dataset_converted_externally_to_rlds/0.1.0/")
     videos_dir = Path("/hdd/tmp/")
     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
-        raw_dir, videos_dir, fps=5, video=True, episodes=None,
+        raw_dir,
+        videos_dir,
+        fps=50,
+        video=True,
+        episodes=[2, 3],
     )
     print(hf_dataset)
     print(episode_data_index)
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index cb2fee95..d9b97c98 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -59,17 +59,35 @@ def unflatten_dict(d, sep="/"):
     return outdict
 
 
-def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
+def hf_transform_to_torch(
+    items_dict: dict[torch.Tensor | None],
+    lang_tokenizer_name: str = "t5-small",
+):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
     a channel last representation (h w c) of uint8 type, to a torch image representation
     with channel first (c h w) of float32 type in range [0,1].
     """
+    # tokenize language instructions if it exists
+    if "language_instruction" in items_dict:
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(lang_tokenizer_name)
+        tokenizer_kwargs = {
+            "max_length": 64,  # NOTE: adjust this value accordingly
+            "padding": "max_length",
+            "truncation": True,
+            "return_tensors": "pt",
+        }
+
     for key in items_dict:
         first_item = items_dict[key][0]
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        elif isinstance(first_item, str):
+            # convert str to lang embeddings via language tokenizer
+            items_dict[key] = [tokenizer.encode(x, **tokenizer_kwargs) for x in items_dict[key]]
         elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
             # video frame will be processed downstream
             pass