WIP Add aloha_dora_format

Remi Cadene 2024-05-21 21:16:31 +00:00
parent 8460ea6f83
commit b0cb342795
2 changed files with 158 additions and 31 deletions

lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py

@ -1,73 +1,196 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format from dora-record
"""
import logging
import re
import shutil
from pathlib import Path
import pandas as pd
import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame
from lerobot.common.utils.utils import init_logging
def check_format(raw_dir) -> bool:
    # TODO(rcadene): remove hardcoding
    raw_dir = raw_dir / "018f9c37-c092-72fd-bd83-6f5a5c1b59d2"
    assert raw_dir.exists()

    parquet_files = list(raw_dir.glob("*.parquet"))
    if len(parquet_files) == 0:
        raise ValueError(f"Missing parquet files in '{raw_dir}'")
    return True
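
For reference, the directory layout this module expects, inferred from the globs used below (the session id is hardcoded for now; all file names except the reference camera are illustrative):

    018f9c37-c092-72fd-bd83-6f5a5c1b59d2/
    ├── action.parquet                                  (one parquet per data stream, named by key)
    ├── episode_index.parquet
    ├── observation.state.parquet
    ├── observation.images.cam_right_wrist.parquet
    └── videos/
        └── observation.images.cam_right_wrist_0.mp4   (one mp4 per camera per episode)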
def load_from_raw(raw_dir: Path, out_dir: Path):
    # TODO(rcadene): remove hardcoding
    raw_dir = raw_dir / "018f9c37-c092-72fd-bd83-6f5a5c1b59d2"

    # Load the data stream that will be used as reference for timestamp synchronization
    reference_key = "observation.images.cam_right_wrist"
    reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
    reference_df = reference_df[["timestamp_utc", reference_key]]

    # Merge every other data stream onto the reference using a nearest-backward strategy
    data_df = reference_df
    for path in raw_dir.glob("*.parquet"):
        key = path.stem  # e.g. "action" or "observation.state"
        if key == reference_key:
            continue
        df = pd.read_parquet(path)
        df = df[["timestamp_utc", key]]
        data_df = pd.merge_asof(
            data_df,
            df,
            on="timestamp_utc",
            direction="backward",
        )
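
Because the camera, state, and action streams are recorded at different rates, each reference frame picks up the most recent value of every other stream at or before its timestamp. A minimal, self-contained sketch of this nearest-backward merge, with made-up timestamps:

    import pandas as pd

    ref = pd.DataFrame(
        {"timestamp_utc": pd.to_datetime(["2024-05-21 00:00:00.000", "2024-05-21 00:00:00.033"])}
    )
    state = pd.DataFrame(
        {
            "timestamp_utc": pd.to_datetime(["2024-05-21 00:00:00.010"]),
            "observation.state": [[0.1, 0.2]],
        }
    )
    merged = pd.merge_asof(ref, state, on="timestamp_utc", direction="backward")
    # Row 0 (t=0.000) has no state at or before it -> NaN (these rows are dropped later).
    # Row 1 (t=0.033) picks up the most recent state, recorded at t=0.010.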
    # dora only uses arrays, so single values are encapsulated into a list
    data_df["episode_index"] = data_df["episode_index"].map(lambda x: x[0])
    data_df["frame_index"] = data_df.groupby("episode_index").cumcount()
    data_df["index"] = data_df.index

    # set 'next.done' to True for the last frame of each episode
    data_df["next.done"] = False
    data_df.loc[data_df.groupby("episode_index").tail(1).index, "next.done"] = True

    # Compute the range of global frame indices covered by each episode
    first_ep_index_df = data_df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
    from_ = first_ep_index_df["start_index"].tolist()
    to_ = from_[1:] + [len(data_df)]
    episode_data_index = {
        "from": from_,
        "to": to_,
    }
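
To make the range semantics concrete (made-up episode lengths): two episodes of 100 and 120 frames yield

    # episode_data_index == {"from": [0, 100], "to": [100, 220]}

so episode e spans the half-open interval of global frame indices [from[e], to[e]).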
    data_df["timestamp"] = data_df["timestamp_utc"].map(lambda x: x.timestamp())
    # each episode starts at timestamp 0 to match the timestamps of the video frames
    data_df["timestamp"] = data_df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])

    del data_df["timestamp_utc"]

    # Remove rows with a NaN in any column. This can happen during the first frames of an
    # episode, because some cameras have not started recording yet.
    data_df = data_df.dropna(axis=0)
    # Create a symlink to the raw videos directory (it needs to be absolute, not relative)
    # out_dir.mkdir(parents=True, exist_ok=True)
    # absolute_videos_dir = (raw_dir / "videos").absolute()
    # (out_dir / "videos").symlink_to(absolute_videos_dir)

    # TODO(rcadene): remove before merge
    # Copy the videos, zero-padding the episode index in the file name to 6 digits
    (out_dir / "videos").mkdir(parents=True, exist_ok=True)
    for from_path in (raw_dir / "videos").glob("*.mp4"):
        match = re.search(r"_(\d+)\.mp4$", from_path.name)
        if not match:
            raise ValueError(f"Unexpected video file name: {from_path.name}")
        ep_idx = match.group(1)
        to_path = out_dir / "videos" / from_path.name.replace(ep_idx, f"{int(ep_idx):06d}")
        shutil.copy2(from_path, to_path)
    data_dict = {}
    for key in data_df:
        # is video frame
        if "observation.images." in key:
            # We need `[0]` because dora only uses arrays, so single values are encapsulated
            # into a list. It is the case for the video frame dictionary:
            # [{"path": ..., "timestamp": ...}]
            data_dict[key] = [video_frame[0] for video_frame in data_df[key].values]

            # TODO(rcadene): remove before merge
            # Zero-pad the episode index in each video path to match the copied files
            for item in data_dict[key]:
                path = item["path"]
                match = re.search(r"_(\d+)\.mp4$", path)
                if not match:
                    raise ValueError(f"Unexpected video path: {path}")
                ep_idx = match.group(1)
                item["path"] = path.replace(ep_idx, f"{int(ep_idx):06d}")
        # is a single number
        elif data_df[key].iloc[0].ndim == 0 or data_df[key].iloc[0].shape[0] == 1:
            data_dict[key] = torch.from_numpy(data_df[key].values)
        # is a vector
        elif data_df[key].iloc[0].shape[0] > 1:
            data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in data_df[key].values])
        else:
            raise ValueError(key)

    return data_dict, episode_data_index
def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    image_keys = [key for key in data_dict if "observation.images." in key]
    for key in image_keys:
        if video:
            features[key] = VideoFrame()
        else:
            features[key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    if "observation.velocity" in data_dict:
        features["observation.velocity"] = Sequence(
            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
        )
    if "observation.effort" in data_dict:
        features["observation.effort"] = Sequence(
            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
        )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
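
Since `hf_transform_to_torch` is set as the dataset transform, indexing the returned dataset should yield torch tensors. A sketch of the expected access pattern (shapes are assumptions based on the Sequence features above):

    item = hf_dataset[0]
    item["observation.state"]  # float32 tensor, length = number of joints
    item["action"]             # float32 tensor, same convention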
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    init_logging()

    if debug:
        logging.warning("debug=True not implemented. Falling back to debug=False.")

    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 30
    else:
        raise NotImplementedError()

    if not video:
        raise NotImplementedError()

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
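
The diff view cuts off here. For orientation, a sketch of calling this entry point directly; the paths are hypothetical, and the return triple assumes the (hf_dataset, episode_data_index, info) convention of the sibling *_format modules:

    from pathlib import Path

    from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format

    raw_dir = Path("data/aloha_dora_raw")      # hypothetical: holds the dora-record session
    out_dir = Path("data/aloha_dora_lerobot")  # hypothetical: receives the copied videos

    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir)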

lerobot/scripts/push_dataset_to_hub.py

@ -84,10 +84,14 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_dora":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
    elif raw_format == "xarm_pkl":
        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
    else:
        raise ValueError(
            f"The raw format '{raw_format}' can't be found. Did you add it to "
            "`lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
        )
    return from_raw_to_lerobot_format
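
A sketch of the dispatcher in use; "aloha_dora" matches the branch added above, and the paths and return shape follow the assumptions of the earlier sketch:

    fn = get_from_raw_to_lerobot_format_fn("aloha_dora")
    hf_dataset, episode_data_index, info = fn(raw_dir, out_dir)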