WIP Add aloha_dora_format
This commit is contained in:
parent
8460ea6f83
commit
b0cb342795
|
@ -1,73 +1,196 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Contains utilities to process raw data format from dora-record
|
Contains utilities to process raw data format from dora-record
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from datasets import Dataset
|
import torch
|
||||||
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
|
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
hf_transform_to_torch,
|
||||||
|
)
|
||||||
|
from lerobot.common.datasets.video_utils import VideoFrame
|
||||||
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
def check_format(raw_dir: Path) -> bool:
    """Sanity-check that `raw_dir` contains a recording with parquet files.

    Args:
        raw_dir: Directory that contains the recording sub-directory.

    Returns:
        True when the recording directory exists and holds at least one
        `*.parquet` file.

    Raises:
        ValueError: If the recording directory is missing or has no parquet file.
    """
    # TODO(rcadene): remove hardcoding
    raw_dir = raw_dir / "018f9c37-c092-72fd-bd83-6f5a5c1b59d2"

    # Explicit check instead of `assert`: assertions are stripped under `python -O`.
    if not raw_dir.exists():
        raise ValueError(f"Missing recording directory '{raw_dir}'")

    # Only the existence of one parquet file matters, no need to list them all.
    if next(raw_dir.glob("*.parquet"), None) is None:
        raise ValueError(f"Missing parquet files in '{raw_dir}'")

    return True
|
||||||
|
|
||||||
|
|
||||||
_EPISODE_SUFFIX = re.compile(r"_(\d+)\.mp4$")


def _pad_episode_index(name: str) -> str:
    """Zero-pad to 6 digits the episode number found in a trailing `_<n>.mp4`."""
    match = _EPISODE_SUFFIX.search(name)
    if not match:
        raise ValueError(name)
    # Rewrite only the matched digits. `str.replace` would also rewrite any other
    # occurrence of the same digits earlier in the name or path.
    start, end = match.span(1)
    return f"{name[:start]}{int(match.group(1)):06d}{name[end:]}"


def load_from_raw(raw_dir: Path, out_dir: Path):
    """Load the parquet data streams of a recording and synchronize them.

    Every `*.parquet` file of the recording is merged onto the timestamps of a
    reference camera stream. Frames for which at least one stream has no value are
    discarded, then the per-frame bookkeeping columns (`episode_index`,
    `frame_index`, `index`, `next.done`, `timestamp`) are computed.

    Args:
        raw_dir: Directory that contains the recording sub-directory.
        out_dir: Output directory; the raw videos are copied to `out_dir / "videos"`.

    Returns:
        A tuple `(data_dict, episode_data_index)`. `data_dict` maps each column name
        to a list of video frame dictionaries (image keys) or to a tensor (all other
        keys). `episode_data_index` has the keys "from" and "to", holding the first
        row index of each episode and the first row index of the next one.

    Raises:
        ValueError: If no frame has data from every stream, if a video file name or
            video frame path does not end with `_<episode>.mp4`, or if a column is
            neither a number nor a vector.
    """
    # TODO(rcadene): remove hardcoding
    raw_dir = raw_dir / "018f9c37-c092-72fd-bd83-6f5a5c1b59d2"

    # Load data stream that will be used as reference for the timestamps synchronization
    reference_key = "observation.images.cam_right_wrist"
    reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
    reference_df = reference_df[["timestamp_utc", reference_key]]

    # Merge all data streams using the nearest backward strategy
    data_df = reference_df
    for path in raw_dir.glob("*.parquet"):
        key = path.stem  # action or observation.state or ...
        if key == reference_key:
            continue
        df = pd.read_parquet(path)
        df = df[["timestamp_utc", key]]
        data_df = pd.merge_asof(
            data_df,
            df,
            on="timestamp_utc",
            direction="backward",
        )

    # Remove rows with a NaN in any column. It can happen during the first frames of an
    # episode, because some streams didn't start recording yet. This must drop rows
    # (axis=0), not columns, and must run before any index is derived so that
    # `index`, `frame_index` and `episode_data_index` describe the rows actually kept.
    data_df = data_df.dropna(axis=0).reset_index(drop=True)
    if data_df.empty:
        raise ValueError(f"No frame with data from every stream in '{raw_dir}'")

    # dora only uses arrays, so single values are encapsulated into a list
    data_df["episode_index"] = data_df["episode_index"].map(lambda x: x[0])
    data_df["frame_index"] = data_df.groupby("episode_index").cumcount()
    data_df["index"] = data_df.index

    # set 'next.done' to True for the last frame of each episode
    data_df["next.done"] = False
    data_df.loc[data_df.groupby("episode_index").tail(1).index, "next.done"] = True

    # Get the first row index of each unique episode index
    first_ep_index_df = data_df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
    from_ = first_ep_index_df["start_index"].tolist()
    to_ = from_[1:] + [len(data_df)]
    episode_data_index = {
        "from": from_,
        "to": to_,
    }

    data_df["timestamp"] = data_df["timestamp_utc"].map(lambda x: x.timestamp())
    # each episode starts with timestamp 0 to match the ones from the video
    # NOTE(review): 0 is now the first *kept* frame of the episode — confirm this is
    # the intended origin when leading frames were dropped above.
    data_df["timestamp"] = data_df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])

    del data_df["timestamp_utc"]

    # Create symlink to raw videos directory (that needs to be absolute not relative)
    # out_dir.mkdir(parents=True, exist_ok=True)
    # absolute_videos_dir = (raw_dir / "videos").absolute()
    # (out_dir / "videos").symlink_to(absolute_videos_dir)

    # TODO(rcadene): remove before merge
    (out_dir / "videos").mkdir(parents=True, exist_ok=True)
    for from_path in (raw_dir / "videos").glob("*.mp4"):
        to_path = out_dir / "videos" / _pad_episode_index(from_path.name)
        shutil.copy2(from_path, to_path)

    data_dict = {}
    for key in data_df:
        # is video frame
        if "observation.images." in key:
            # we need `[0]` because dora only uses arrays, so single values are encapsulated into a list.
            # it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
            data_dict[key] = [video_frame[0] for video_frame in data_df[key].values]

            # TODO(rcadene): remove before merge
            for item in data_dict[key]:
                item["path"] = _pad_episode_index(item["path"])
        # is number
        elif data_df[key].iloc[0].ndim == 0 or data_df[key].iloc[0].shape[0] == 1:
            data_dict[key] = torch.from_numpy(data_df[key].values)
        # is vector
        elif data_df[key].iloc[0].shape[0] > 1:
            data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in data_df[key].values])
        else:
            raise ValueError(key)

    return data_dict, episode_data_index
|
||||||
|
|
||||||
|
|
||||||
def to_hf_dataset(data_dict, video) -> Dataset:
    """Build a Hugging Face `Dataset` from the column dictionary `data_dict`.

    Args:
        data_dict: Mapping from column name to column data. It must contain
            "observation.state" and "action"; "observation.velocity" and
            "observation.effort" are optional. Keys containing
            "observation.images." are treated as image columns.
        video: When truthy, image columns use the `VideoFrame` feature,
            otherwise the `Image` feature.

    Returns:
        The dataset, with `hf_transform_to_torch` set as its transform.
    """

    def float_vector(name):
        # float32 sequence whose length is the second dimension of the column
        return Sequence(length=data_dict[name].shape[1], feature=Value(dtype="float32", id=None))

    schema = {}

    # Image columns come first, one feature instance per key.
    for name in data_dict:
        if "observation.images." not in name:
            continue
        schema[name] = VideoFrame() if video else Image()

    schema["observation.state"] = float_vector("observation.state")
    for optional_name in ("observation.velocity", "observation.effort"):
        if optional_name in data_dict:
            schema[optional_name] = float_vector(optional_name)
    schema["action"] = float_vector("action")

    # Per-frame scalar columns.
    scalar_dtypes = (
        ("episode_index", "int64"),
        ("frame_index", "int64"),
        ("timestamp", "float32"),
        ("next.done", "bool"),
        ("index", "int64"),
    )
    for name, dtype in scalar_dtypes:
        schema[name] = Value(dtype=dtype, id=None)

    dataset = Dataset.from_dict(data_dict, features=Features(schema))
    dataset.set_transform(hf_transform_to_torch)
    return dataset
|
||||||
|
|
||||||
|
|
||||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
logging.warning("debug=True not implemented. Falling back to debug=False.")
|
||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
check_format(raw_dir)
|
check_format(raw_dir)
|
||||||
|
|
||||||
if fps is None:
|
if fps is None:
|
||||||
fps = 30
|
fps = 30
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
if not video:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
data_df, episode_data_index = load_from_raw(raw_dir, out_dir)
|
||||||
hf_dataset = to_hf_dataset(data_df, video)
|
hf_dataset = to_hf_dataset(data_df, video)
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
|
|
|
@ -84,10 +84,14 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "aloha_hdf5":
|
elif raw_format == "aloha_hdf5":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||||
|
elif raw_format == "aloha_dora":
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "xarm_pkl":
|
elif raw_format == "xarm_pkl":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||||
else:
|
else:
|
||||||
raise ValueError(raw_format)
|
raise ValueError(
|
||||||
|
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
|
||||||
|
)
|
||||||
|
|
||||||
return from_raw_to_lerobot_format
|
return from_raw_to_lerobot_format
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue