experimental RLDS to HF format converter and uploader

Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
youliangtan 2024-06-20 12:04:51 -07:00
parent b72d574891
commit 4ab4950da8
2 changed files with 241 additions and 0 deletions

lerobot/common/datasets/push_dataset_to_hub/oxe_rlds_format.py

@@ -0,0 +1,239 @@
#!/usr/bin/env python
"""
Convert RLDS datasets from Open X-Embodiment (OXE),
https://github.com/google-deepmind/open_x_embodiment, into the LeRobot
dataset format and push them to the Hugging Face Hub.

Example:
    python lerobot/scripts/push_dataset_to_hub.py \
        --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
        --repo-id youliangtan/sampled_bridge_data_v2 \
        --raw-format oxe_rlds \
        --episodes 3 4 5 8 9
"""
import shutil
from pathlib import Path

import numpy as np
import tensorflow_datasets as tfds
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import (
    concatenate_episodes,
    save_images_concurrently,
)
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

np.set_printoptions(precision=2)


def get_cameras_keys(obs_keys) -> list[str]:
    """Return the observation keys that correspond to camera images (e.g. "image", "wrist_image")."""
    return [key for key in obs_keys if "image" in key]


def tf_to_torch(data):
    """Convert a TensorFlow tensor to a torch tensor."""
    return torch.from_numpy(data.numpy())


def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
):
    """Load an RLDS/TFDS dataset and convert it into a LeRobot-style data dict.

    Args:
        raw_dir (Path): directory of the TFDS builder, e.g. `.../bridge_dataset/1.0.0/`.
        videos_dir (Path): directory where encoded mp4 videos are written when `video=True`.
        fps (int): frames per second used for timestamps and video encoding.
        video (bool): if True, encode camera frames as mp4 videos; otherwise keep PIL images.
        episodes (list[int] | None, optional): episode indices to convert. Defaults to None (all episodes).
    """
    ds_builder = tfds.builder_from_directory(str(raw_dir))
    dataset = ds_builder.as_dataset(split="all")
    dataset_info = ds_builder.info
    print("dataset_info: ", dataset_info)

    image_keys = get_cameras_keys(dataset_info.features["steps"]["observation"].keys())
    print("image_keys: ", image_keys)

    ds_length = len(dataset)
    dataset = dataset.take(ds_length)
    it = iter(dataset)

    ep_dicts = []

    # if the user specified episodes, convert the indices to a sorted list
    # so they can be consumed in order while iterating over the dataset
    if episodes is not None:
        if ds_length == 0:
            raise ValueError("No episodes found.")
        episodes = sorted(episodes)

    for ep_idx in tqdm.tqdm(range(ds_length)):
        episode = next(it)

        # if the user specified episodes, skip the ones that are not in the list
        if episodes is not None:
            if len(episodes) == 0:
                break
            if ep_idx == episodes[0]:
                # process this episode
                print(" selecting episode: ", ep_idx)
                episodes.pop(0)
            else:
                continue  # skip this episode

        steps = episode["steps"]
        eps_len = len(steps)
        num_frames = eps_len  # TODO: check if this is correct

        # the last step of the demonstration is considered done
        done = torch.zeros(num_frames, dtype=torch.bool)
        done[-1] = True

        states = []
        actions = []
        ep_dict = {}
        image_array_dict = {key: [] for key in image_keys}

        ###########################################################
        # loop through all steps in the episode
        for step in steps:
            states.append(tf_to_torch(step["observation"]["state"]))
            actions.append(tf_to_torch(step["action"]))
            # if "language_text" in step:
            #     print(" - lang: ", step["language_text"])

            for im_key in image_keys:
                if im_key not in step["observation"]:
                    continue
                img = step["observation"][im_key]
                img = np.array(img)
                image_array_dict[im_key].append(img)

        ###########################################################
        # loop through all cameras and store frames as videos or images
        for im_key in image_keys:
            img_key = f"observation.images.{im_key}"
            imgs_array = np.array(image_array_dict[im_key])

            if video:
                # save png images in a temporary directory
                tmp_imgs_dir = videos_dir / "tmp_images"
                save_images_concurrently(imgs_array, tmp_imgs_dir)

                # encode images to an mp4 video
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                video_path = videos_dir / fname
                encode_video_frames(tmp_imgs_dir, video_path, fps)

                # clean up the temporary images directory
                shutil.rmtree(tmp_imgs_dir)

                # store the reference to the video frames
                ep_dict[img_key] = [
                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
                ]
            else:
                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = torch.stack(states) # TODO better way
ep_dict["action"] = torch.stack(actions)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["next.done"] = done
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict


def to_hf_dataset(data_dict, video: bool) -> Dataset:
    features = {}

    keys = [key for key in data_dict if "observation.images." in key]
    for key in keys:
        if video:
            features[key] = VideoFrame()
        else:
            features[key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    if "observation.velocity" in data_dict:
        features["observation.velocity"] = Sequence(
            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
        )
    if "observation.effort" in data_dict:
        features["observation.effort"] = Sequence(
            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
        )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
):
    """Experimental conversion of an RLDS (OXE) dataset to the LeRobot format."""
    if fps is None:
        fps = 5

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info


if __name__ == "__main__":
    # TODO (YL): remove this local test
    raw_dir = Path("/hdd/serl/serl_task1_combine_13jun/")
    videos_dir = Path("/hdd/serl/tmp/")
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir, videos_dir, fps=5, video=True, episodes=None,
    )
    print(hf_dataset)
    print(episode_data_index)
    print(info)
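    # Minimal sketch (not part of the script's normal flow): the resulting
    # `hf_dataset` is a regular `datasets.Dataset`, so it could also be pushed
    # directly, e.g. `hf_dataset.push_to_hub("youliangtan/sampled_bridge_data_v2")`.
    # The usual path, however, is the push_dataset_to_hub.py command shown in
    # the module docstring.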

lerobot/scripts/push_dataset_to_hub.py

@@ -66,6 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
    elif raw_format == "oxe_rlds":
        from lerobot.common.datasets.push_dataset_to_hub.oxe_rlds_format import from_raw_to_lerobot_format
    elif raw_format == "dora_parquet":
        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
    elif raw_format == "xarm_pkl":