add lang tokenizer

Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
Author: youliangtan
Date: 2024-06-26 19:52:46 -07:00
parent 61e51c9fe4
commit a644084f98
3 changed files with 70 additions and 38 deletions


@@ -43,6 +43,10 @@ def get_stats_einops_patterns(dataset, num_workers=0):
         # sanity check that tensors are not float64
         assert batch[key].dtype != torch.float64

+        # NOTE: skip language_instruction embedding in stats computation
+        if key == "language_instruction":
+            continue
+
         if isinstance(feats_type, (VideoFrame, Image)):
             # sanity check that images are channel first
             _, c, h, w = batch[key].shape
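The skip above matters because this function derives per-key normalization statistics (min/max/mean/std via einops reductions), and a tokenized language instruction is a padded tensor of integer token ids for which such statistics are meaningless. A minimal sketch of the idea, with a simplified loop and patterns that are assumptions rather than the exact lerobot internals:

import torch

def stats_patterns_sketch(batch: dict[str, torch.Tensor]) -> dict[str, str]:
    """Assign an einops reduction pattern per key, skipping token-id tensors."""
    patterns = {}
    for key, tensor in batch.items():
        if key == "language_instruction":
            continue  # integer token ids: min/max/mean/std are meaningless
        if tensor.ndim == 4:  # image batch (b, c, h, w)
            patterns[key] = "b c h w -> c 1 1"
        else:  # state/action batch (b, d)
            patterns[key] = "b c -> c"
    return patterns

batch = {
    "observation.state": torch.randn(8, 7),
    "action": torch.randn(8, 7),
    "language_instruction": torch.randint(0, 32100, (8, 1, 64)),  # padded T5 ids
}
print(stats_patterns_sketch(batch))  # no "language_instruction" entry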


@@ -8,21 +8,22 @@ Example:
     --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
     --repo-id youliangtan/sampled_bridge_data_v2 \
     --raw-format oxe_rlds \
-    --episodes 3 4 5 8 9
+    --episodes 3 4 5 8 9 \
+    --fps 5
+
+Exact dataset fps is specified in:
+https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
 """

-import gc
 import shutil
 from pathlib import Path

-import h5py
 import numpy as np
-import tensorflow_datasets as tfds
 import torch
 import tqdm
 from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
+import tensorflow_datasets as tfds
+import cv2

 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
@@ -42,13 +43,7 @@ def tf_to_torch(data):
     return torch.from_numpy(data.numpy())


-def load_from_raw(
-    raw_dir: Path,
-    videos_dir: Path,
-    fps: int,
-    video: bool,
-    episodes: list[int] | None = None
-):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     """
     Args:
         raw_dir (Path): _description_
@@ -58,13 +53,19 @@ def load_from_raw(
         episodes (list[int] | None, optional): _description_. Defaults to None.
     """
     ds_builder = tfds.builder_from_directory(str(raw_dir))
-    dataset = ds_builder.as_dataset(split='all')
+    dataset = ds_builder.as_dataset(split="all")
     dataset_info = ds_builder.info
     print("dataset_info: ", dataset_info)

-    image_keys = get_cameras_keys(
-        dataset_info.features["steps"]["observation"].keys())
-    print("image_keys: ", image_keys)
+    image_keys = get_cameras_keys(dataset_info.features["steps"]["observation"].keys())
+
+    # check if there's a 'tfds.features.Text' in step, only take 1 lang instruction
+    lang_key = [
+        key for key, value in dataset_info.features["steps"].items() if isinstance(value, tfds.features.Text)
+    ]
+    lang_key = None if len(lang_key) == 0 else lang_key[0]
+    print(" - image_keys: ", image_keys)
+    print(" - lang_key: ", lang_key)

     ds_length = len(dataset)
     dataset = dataset.take(ds_length)
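RLDS/Open X-Embodiment builders expose the per-step schema as a tfds FeaturesDict, and a natural-language instruction, when present, appears there as a tfds.features.Text feature; the list comprehension above simply takes the first such key. A standalone sketch of that detection, using a hypothetical step spec rather than a real dataset builder:

import numpy as np
import tensorflow_datasets as tfds

# hypothetical step spec mirroring an OXE/RLDS dataset
step_features = tfds.features.FeaturesDict(
    {
        "observation": tfds.features.FeaturesDict(
            {"state": tfds.features.Tensor(shape=(7,), dtype=np.float32)}
        ),
        "action": tfds.features.Tensor(shape=(7,), dtype=np.float32),
        "language_instruction": tfds.features.Text(),
    }
)

lang_keys = [key for key, value in step_features.items() if isinstance(value, tfds.features.Text)]
lang_key = None if len(lang_keys) == 0 else lang_keys[0]
print(lang_key)  # -> language_instruction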
@@ -88,40 +89,41 @@ def load_from_raw(
                 break
             if ep_idx == episodes[0]:
                 # process this episode
-                print(" selecting episode: ", ep_idx)
+                print(" selecting episode idx: ", ep_idx)
                 episodes.pop(0)
             else:
                 continue  # skip

-        steps = episode['steps']
-        eps_len = len(steps)
-        num_frames = eps_len  # TODO: check if this is correct
+        steps = episode["steps"]
+        num_frames = len(steps)

         # last step of demonstration is considered done
         done = torch.zeros(num_frames, dtype=torch.bool)
         done[-1] = True

         states = []
-        actions = []
+        actions = []  # TODO(YL): some actions can be a featuredict
         rewards = torch.zeros(num_frames, dtype=torch.float32)
         ep_dict = {}
+        langs = []

         image_array_dict = {key: [] for key in image_keys}

         ###########################################################
         # loop through all steps in the episode
         for j, step in enumerate(steps):
-            states.append(tf_to_torch(step['observation']['state']))
-            actions.append(tf_to_torch(step['action']))
-            rewards[j] = torch.tensor(step['reward'].numpy(), dtype=torch.float32)
+            states.append(tf_to_torch(step["observation"]["state"]))
+            actions.append(tf_to_torch(step["action"]))
+            rewards[j] = torch.tensor(step["reward"].numpy(), dtype=torch.float32)

-            # TODO: language_text, is_terminal, is_last etc.
+            if lang_key is not None:
+                langs.append(str(step[lang_key]))

             for im_key in image_keys:
-                if im_key not in step['observation']:
+                if im_key not in step["observation"]:
                     continue
-                img = step['observation'][im_key]
+                img = step["observation"][im_key]
                 img = np.array(img)
                 image_array_dict[im_key].append(img)
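One caveat with str(step[lang_key]): in TF eager mode each step field is an EagerTensor, and str() on a scalar string tensor returns its repr rather than the bare text. Decoding the numpy bytes is the usual way to recover the plain string, as this small check shows:

import tensorflow as tf

instruction = tf.constant("pick up the cube")
print(str(instruction))              # tf.Tensor(b'pick up the cube', shape=(), dtype=string)
print(instruction.numpy().decode())  # pick up the cube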
@@ -152,6 +154,9 @@ def load_from_raw(
         else:
             ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

+        if lang_key is not None:
+            ep_dict["language_instruction"] = langs
+
         ep_dict["observation.state"] = torch.stack(states)  # TODO better way
         ep_dict["action"] = torch.stack(actions)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -180,22 +185,21 @@ def to_hf_dataset(data_dict, video) -> Dataset:
             features[key] = Image()

     features["observation.state"] = Sequence(
-        length=data_dict["observation.state"].shape[1], feature=Value(
-            dtype="float32", id=None)
+        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
     )
     if "observation.velocity" in data_dict:
         features["observation.velocity"] = Sequence(
-            length=data_dict["observation.velocity"].shape[1], feature=Value(
-                dtype="float32", id=None)
+            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
         )
     if "observation.effort" in data_dict:
         features["observation.effort"] = Sequence(
-            length=data_dict["observation.effort"].shape[1], feature=Value(
-                dtype="float32", id=None)
+            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
         )
+    if "language_instruction" in data_dict:
+        features["language_instruction"] = Value(dtype="string", id=None)
     features["action"] = Sequence(
-        length=data_dict["action"].shape[1], feature=Value(
-            dtype="float32", id=None)
+        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
     features["episode_index"] = Value(dtype="int64", id=None)
     features["frame_index"] = Value(dtype="int64", id=None)
@@ -231,11 +235,17 @@ def from_raw_to_lerobot_format(

 if __name__ == "__main__":
-    # TODO (YL) remove this
+    # NOTE (YL): This mainly serves as a unit test
+    # austin_buds_dataset_converted_externally_to_rlds is a smaller dataset in
+    # open x embodiment datasets.
     raw_dir = Path("/hdd/tensorflow_datasets/austin_buds_dataset_converted_externally_to_rlds/0.1.0/")
     videos_dir = Path("/hdd/tmp/")
     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
-        raw_dir, videos_dir, fps=5, video=True, episodes=None,
+        raw_dir,
+        videos_dir,
+        fps=50,
+        video=True,
+        episodes=[2, 3],
     )
     print(hf_dataset)
     print(episode_data_index)


@@ -59,17 +59,35 @@ def unflatten_dict(d, sep="/"):
     return outdict


-def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
+def hf_transform_to_torch(
+    items_dict: dict[torch.Tensor | None],
+    lang_tokenizer_name: str = "t5-small",
+):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
     a channel last representation (h w c) of uint8 type, to a torch image representation
     with channel first (c h w) of float32 type in range [0,1].
     """
+    # tokenize language instructions if it exists
+    if "language_instruction" in items_dict:
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(lang_tokenizer_name)
+        tokenizer_kwargs = {
+            "max_length": 64,  # NOTE: adjust this value accordingly
+            "padding": "max_length",
+            "truncation": True,
+            "return_tensors": "pt",
+        }
+
     for key in items_dict:
         first_item = items_dict[key][0]
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        elif isinstance(first_item, str):
+            # convert str to lang embeddings via language tokenizer
+            items_dict[key] = [tokenizer.encode(x, **tokenizer_kwargs) for x in items_dict[key]]
         elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
             # video frame will be processed downstream
             pass
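With these kwargs and the default t5-small tokenizer, every instruction string becomes a fixed-size tensor: tokenizer.encode with return_tensors="pt" yields a LongTensor of shape [1, 64], truncated or padded with T5's pad id (0). A quick check:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
ids = tokenizer.encode(
    "pick up the cube",
    max_length=64,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(ids.shape)   # torch.Size([1, 64])
print(ids[0, :8])  # content ids and eos first, then 0s (T5 pad id)

This also ties back to the first file: the id tensors ride along under the language_instruction key, so the stats pass must skip them because min/max/mean/std over token ids would be meaningless for normalization.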