add lang tokenizer
Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
commit a644084f98 (parent 61e51c9fe4)
@@ -43,6 +43,10 @@ def get_stats_einops_patterns(dataset, num_workers=0):
         # sanity check that tensors are not float64
         assert batch[key].dtype != torch.float64
 
+        # NOTE: skip language_instruction embedding in stats computation
+        if key == "language_instruction":
+            continue
+
         if isinstance(feats_type, (VideoFrame, Image)):
             # sanity check that images are channel first
             _, c, h, w = batch[key].shape
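Note: get_stats_einops_patterns accumulates per-key normalization statistics, and the tokenized language_instruction column holds integer token ids rather than a continuous signal, so it is excluded here. A minimal sketch of the guard in isolation (the batch dictionary below is illustrative, not the actual lerobot batch):

import torch

# illustrative batch: one float sensor key and one int64 token-id key (made-up values)
batch = {
    "observation.state": torch.randn(4, 7),
    "language_instruction": torch.ones(4, 64, dtype=torch.long),
}
for key, tensor in batch.items():
    if key == "language_instruction":
        continue  # token ids carry no meaningful mean/std for normalization
    print(key, tensor.mean().item(), tensor.std().item())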
@@ -8,21 +8,22 @@ Example:
     --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
     --repo-id youliangtan/sampled_bridge_data_v2 \
     --raw-format oxe_rlds \
-    --episodes 3 4 5 8 9
+    --episodes 3 4 5 8 9 \
+    --fps 5
+
+Exact dataset fps is specified in:
+https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
 """
 
-import gc
 import shutil
 from pathlib import Path
 
-import h5py
 import numpy as np
+import tensorflow_datasets as tfds
 import torch
 import tqdm
 from datasets import Dataset, Features, Image, Sequence, Value
 from PIL import Image as PILImage
-import tensorflow_datasets as tfds
-import cv2
 
 from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
 from lerobot.common.datasets.utils import (
@@ -42,13 +43,7 @@ def tf_to_torch(data):
     return torch.from_numpy(data.numpy())
 
 
-def load_from_raw(
-    raw_dir: Path,
-    videos_dir: Path,
-    fps: int,
-    video: bool,
-    episodes: list[int] | None = None
-):
+def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
     """
     Args:
         raw_dir (Path): _description_
@@ -58,13 +53,19 @@ def load_from_raw(
         episodes (list[int] | None, optional): _description_. Defaults to None.
     """
     ds_builder = tfds.builder_from_directory(str(raw_dir))
-    dataset = ds_builder.as_dataset(split='all')
+    dataset = ds_builder.as_dataset(split="all")
     dataset_info = ds_builder.info
     print("dataset_info: ", dataset_info)
 
-    image_keys = get_cameras_keys(
-        dataset_info.features["steps"]["observation"].keys())
-    print("image_keys: ", image_keys)
+    image_keys = get_cameras_keys(dataset_info.features["steps"]["observation"].keys())
+
+    # check if there's a 'tfds.features.Text' in step, only take 1 lang instruction
+    lang_key = [
+        key for key, value in dataset_info.features["steps"].items() if isinstance(value, tfds.features.Text)
+    ]
+    lang_key = None if len(lang_key) == 0 else lang_key[0]
+    print(" - image_keys: ", image_keys)
+    print(" - lang_key: ", lang_key)
 
     ds_length = len(dataset)
     dataset = dataset.take(ds_length)
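Note: the lang_key lookup relies on tfds exposing the step specification as a FeaturesDict whose text fields are tfds.features.Text. A small standalone sketch of that detection on a hand-built spec (the field names and shapes here are illustrative, not taken from a real OXE dataset):

import numpy as np
import tensorflow_datasets as tfds

# hypothetical step spec mirroring the structure inspected above
step_spec = tfds.features.FeaturesDict(
    {
        "observation": tfds.features.FeaturesDict({"state": tfds.features.Tensor(shape=(7,), dtype=np.float32)}),
        "action": tfds.features.Tensor(shape=(7,), dtype=np.float32),
        "language_instruction": tfds.features.Text(),
    }
)
lang_keys = [key for key, value in step_spec.items() if isinstance(value, tfds.features.Text)]
print(lang_keys)  # expected: ['language_instruction']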
@@ -88,40 +89,41 @@ def load_from_raw(
                 break
             if ep_idx == episodes[0]:
                 # process this episode
-                print(" selecting episode: ", ep_idx)
+                print(" selecting episode idx: ", ep_idx)
                 episodes.pop(0)
             else:
                 continue  # skip
 
-        steps = episode['steps']
-        eps_len = len(steps)
-        num_frames = eps_len  # TODO: check if this is correct
+        steps = episode["steps"]
+        num_frames = len(steps)
 
         # last step of demonstration is considered done
         done = torch.zeros(num_frames, dtype=torch.bool)
         done[-1] = True
 
         states = []
-        actions = []
+        actions = []  # TODO(YL): some actions can be a featuredict
         rewards = torch.zeros(num_frames, dtype=torch.float32)
         ep_dict = {}
+        langs = []
 
         image_array_dict = {key: [] for key in image_keys}
 
         ###########################################################
         # loop through all steps in the episode
         for j, step in enumerate(steps):
-            states.append(tf_to_torch(step['observation']['state']))
-            actions.append(tf_to_torch(step['action']))
-            rewards[j] = torch.tensor(step['reward'].numpy(), dtype=torch.float32)
+            states.append(tf_to_torch(step["observation"]["state"]))
+            actions.append(tf_to_torch(step["action"]))
+            rewards[j] = torch.tensor(step["reward"].numpy(), dtype=torch.float32)
 
-            # TODO: language_text, is_terminal, is_last etc.
+            if lang_key is not None:
+                langs.append(str(step[lang_key]))
 
             for im_key in image_keys:
-                if im_key not in step['observation']:
+                if im_key not in step["observation"]:
                     continue
 
-                img = step['observation'][im_key]
+                img = step["observation"][im_key]
                 img = np.array(img)
                 image_array_dict[im_key].append(img)
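Note: every per-step field goes through tf_to_torch (defined earlier in this file) before being stacked. A quick sanity check of that conversion on an eager TF tensor, assuming TensorFlow and torch are both installed:

import tensorflow as tf
import torch

def tf_to_torch(data):
    return torch.from_numpy(data.numpy())

state = tf.constant([0.1, 0.2, 0.3], dtype=tf.float32)
state_pt = tf_to_torch(state)
print(state_pt.dtype, state_pt.shape)  # torch.float32 torch.Size([3])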
@@ -152,6 +154,9 @@ def load_from_raw(
             else:
                 ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
 
+        if lang_key is not None:
+            ep_dict["language_instruction"] = langs
+
         ep_dict["observation.state"] = torch.stack(states)  # TODO better way
         ep_dict["action"] = torch.stack(actions)
         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -180,22 +185,21 @@ def to_hf_dataset(data_dict, video) -> Dataset:
             features[key] = Image()
 
-    features["observation.state"] = Sequence(
-        length=data_dict["observation.state"].shape[1], feature=Value(
-            dtype="float32", id=None)
-    )
+    features["observation.state"] = Sequence(
+        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
+    )
     if "observation.velocity" in data_dict:
-        features["observation.velocity"] = Sequence(
-            length=data_dict["observation.velocity"].shape[1], feature=Value(
-                dtype="float32", id=None)
-        )
+        features["observation.velocity"] = Sequence(
+            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
+        )
     if "observation.effort" in data_dict:
-        features["observation.effort"] = Sequence(
-            length=data_dict["observation.effort"].shape[1], feature=Value(
-                dtype="float32", id=None)
-        )
+        features["observation.effort"] = Sequence(
+            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
+        )
+    if "language_instruction" in data_dict:
+        features["language_instruction"] = Value(dtype="string", id=None)
 
-    features["action"] = Sequence(
-        length=data_dict["action"].shape[1], feature=Value(
-            dtype="float32", id=None)
-    )
+    features["action"] = Sequence(
+        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
+    )
     features["episode_index"] = Value(dtype="int64", id=None)
     features["frame_index"] = Value(dtype="int64", id=None)
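Note: with the new branch, language_instruction is declared as a plain string column next to the float sequences. A minimal sketch of how such a Features spec behaves with the Hugging Face datasets API (column contents below are made up for illustration):

from datasets import Dataset, Features, Sequence, Value

features = Features(
    {
        "action": Sequence(length=7, feature=Value(dtype="float32", id=None)),
        "language_instruction": Value(dtype="string", id=None),
        "episode_index": Value(dtype="int64", id=None),
    }
)
hf_dataset = Dataset.from_dict(
    {
        "action": [[0.0] * 7],
        "language_instruction": ["pick up the blue cup"],  # made-up instruction
        "episode_index": [0],
    },
    features=features,
)
print(hf_dataset.features["language_instruction"])  # Value(dtype='string', id=None)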
@@ -231,11 +235,17 @@ def from_raw_to_lerobot_format(
 
 
 if __name__ == "__main__":
-    # TODO (YL) remove this
+    # NOTE (YL): This mainly serves as a unit test
+    # austin_buds_dataset_converted_externally_to_rlds is a smaller dataset in
+    # open x embodiment datasets.
     raw_dir = Path("/hdd/tensorflow_datasets/austin_buds_dataset_converted_externally_to_rlds/0.1.0/")
     videos_dir = Path("/hdd/tmp/")
     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
-        raw_dir, videos_dir, fps=5, video=True, episodes=None,
+        raw_dir,
+        videos_dir,
+        fps=50,
+        video=True,
+        episodes=[2, 3],
     )
     print(hf_dataset)
     print(episode_data_index)
@@ -59,17 +59,35 @@ def unflatten_dict(d, sep="/"):
     return outdict
 
 
-def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
+def hf_transform_to_torch(
+    items_dict: dict[torch.Tensor | None],
+    lang_tokenizer_name: str = "t5-small",
+):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
     a channel last representation (h w c) of uint8 type, to a torch image representation
     with channel first (c h w) of float32 type in range [0,1].
     """
+    # tokenize language instructions if it exists
+    if "language_instruction" in items_dict:
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(lang_tokenizer_name)
+        tokenizer_kwargs = {
+            "max_length": 64,  # NOTE: adjust this value accordingly
+            "padding": "max_length",
+            "truncation": True,
+            "return_tensors": "pt",
+        }
+
     for key in items_dict:
         first_item = items_dict[key][0]
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        elif isinstance(first_item, str):
+            # convert str to lang embeddings via language tokenizer
+            items_dict[key] = [tokenizer.encode(x, **tokenizer_kwargs) for x in items_dict[key]]
         elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
             # video frame will be processed downstream
             pass
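Note: with the defaults above, each language_instruction string becomes a fixed-length (1, 64) tensor of T5 token ids. A short usage sketch of the same tokenizer call outside the transform, assuming the transformers package and a cached or downloadable t5-small checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokenizer_kwargs = {
    "max_length": 64,
    "padding": "max_length",
    "truncation": True,
    "return_tensors": "pt",
}

# made-up instruction string for illustration
ids = tokenizer.encode("pick up the blue cup", **tokenizer_kwargs)
print(ids.shape)  # torch.Size([1, 64]): one padded sequence of token ids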