Fix episode_index in aloha datasets, Add velocity and effort to real aloha
This commit is contained in:
parent
e0d851f7ce
commit
52938d2b72
|
@ -18,7 +18,6 @@ Contains utilities to process raw data format of HDF5 files like in: https://git
|
|||
"""
|
||||
|
||||
import gc
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -80,10 +79,8 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
|
||||
for ep_path in tqdm.tqdm(hdf5_files, total=len(hdf5_files)):
|
||||
for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
ep_idx = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
|
@ -92,6 +89,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
if "/observations/qvel" in ep:
|
||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||
if "/observations/effort" in ep:
|
||||
effort = torch.from_numpy(ep["/observations/effort"][:])
|
||||
|
||||
ep_dict = {}
|
||||
|
||||
|
@ -132,6 +133,10 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
|||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
if "/observations/velocity" in ep:
|
||||
ep_dict["observation.velocity"] = velocity
|
||||
if "/observations/effort" in ep:
|
||||
ep_dict["observation.effort"] = effort
|
||||
ep_dict["action"] = action
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
|
@ -170,6 +175,14 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
|||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue