Address comments

This commit is contained in:
Cadene 2024-04-16 17:14:40 +00:00
parent b241ea46dd
commit 36d9e885ef
24 changed files with 100 additions and 94 deletions

View File

@ -17,6 +17,17 @@ from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage from PIL import Image as PILImage
def download_and_upload(root, root_tests, dataset_id):
    """Dispatch to the download-and-upload routine matching ``dataset_id``.

    The dataset family is selected by substring match on ``dataset_id``,
    checked in the order "pusht", "xarm", "aloha"; the first match wins.

    Args:
        root: Forwarded unchanged to the selected routine.
        root_tests: Forwarded unchanged to the selected routine.
        dataset_id: Dataset identifier; must contain one of the known
            family names.

    Raises:
        ValueError: If ``dataset_id`` matches none of the known families.
    """
    if "pusht" in dataset_id:
        download_and_upload_pusht(root, root_tests, dataset_id)
    elif "xarm" in dataset_id:
        download_and_upload_xarm(root, root_tests, dataset_id)
    elif "aloha" in dataset_id:
        download_and_upload_aloha(root, root_tests, dataset_id)
    else:
        # Name the offending id and the accepted families so a typo in the
        # dataset list is easy to diagnose.
        raise ValueError(
            f"Unsupported dataset_id: {dataset_id!r} "
            "(expected an id containing 'pusht', 'xarm' or 'aloha')"
        )
def download_and_extract_zip(url: str, destination_folder: Path) -> bool: def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
import zipfile import zipfile
@ -87,7 +98,6 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
states = torch.from_numpy(dataset_dict["state"]) states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"]) actions = torch.from_numpy(dataset_dict["action"])
data_ids_per_episode = {}
ep_dicts = [] ep_dicts = []
id_from = 0 id_from = 0
@ -150,15 +160,11 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"next.reward": torch.cat([reward[1:], reward[[-1]]]), "next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]), "next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]), "next.success": torch.cat([success[1:], success[[-1]]]),
"episode_data_id_from": torch.tensor([id_from] * num_frames), "episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
} }
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
assert len(data_ids_per_episode[episode_id]) == num_frames
id_from += num_frames id_from += num_frames
data_dict = {} data_dict = {}
@ -190,8 +196,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"next.done": Value(dtype="bool", id=None), "next.done": Value(dtype="bool", id=None),
"next.success": Value(dtype="bool", id=None), "next.success": Value(dtype="bool", id=None),
"index": Value(dtype="int64", id=None), "index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None), "episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None), "episode_data_index_to": Value(dtype="int64", id=None),
} }
features = Features(features) features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features) dataset = Dataset.from_dict(data_dict, features=features)
@ -265,8 +271,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
# "next.observation.state": next_state, # "next.observation.state": next_state,
"next.reward": next_reward, "next.reward": next_reward,
"next.done": next_done, "next.done": next_done,
"episode_data_id_from": torch.tensor([id_from] * num_frames), "episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
} }
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)
@ -301,8 +307,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
"next.done": Value(dtype="bool", id=None), "next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None), #'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None), "index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None), "episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None), "episode_data_index_to": Value(dtype="int64", id=None),
} }
features = Features(features) features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features) dataset = Dataset.from_dict(data_dict, features=features)
@ -390,7 +396,16 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
state = torch.from_numpy(ep["/observations/qpos"][:]) state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:]) action = torch.from_numpy(ep["/action"][:])
ep_dict = { ep_dict = {}
for cam in cameras[dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
# ep_dict[f"next.observation.images.{cam}"] = image
ep_dict.update(
{
"observation.state": state, "observation.state": state,
"action": action, "action": action,
"episode_id": torch.tensor([ep_id] * num_frames), "episode_id": torch.tensor([ep_id] * num_frames),
@ -401,15 +416,10 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
# "next.reward": reward, # "next.reward": reward,
"next.done": done, "next.done": done,
# "next.success": success, # "next.success": success,
"episode_data_id_from": torch.tensor([id_from] * num_frames), "episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
} }
)
for cam in cameras[dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
# ep_dict[f"next.observation.images.{cam}"] = image
assert isinstance(ep_id, int) assert isinstance(ep_id, int)
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)
@ -446,8 +456,8 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
"next.done": Value(dtype="bool", id=None), "next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None), #'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None), "index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None), "episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None), "episode_data_index_to": Value(dtype="int64", id=None),
} }
features = Features(features) features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features) dataset = Dataset.from_dict(data_dict, features=features)
@ -461,23 +471,17 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
if __name__ == "__main__": if __name__ == "__main__":
root = "data" root = "data"
root_tests = "{root_tests}" root_tests = "tests/data"
download_and_upload_pusht(root, root_tests, dataset_id="pusht")
download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
dataset_ids = [ dataset_ids = [
"pusht", # "pusht",
"xarm_lift_medium", # "xarm_lift_medium",
"aloha_sim_insertion_human", # "aloha_sim_insertion_human",
"aloha_sim_insertion_scripted", # "aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human", # "aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted", "aloha_sim_transfer_cube_scripted",
] ]
for dataset_id in dataset_ids: for dataset_id in dataset_ids:
download_and_upload(root, root_tests, dataset_id)
# assume stats have been precomputed # assume stats have been precomputed
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth") shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")

View File

@ -7,9 +7,9 @@ import tqdm
def load_previous_and_future_frames( def load_previous_and_future_frames(
item: dict[torch.Tensor], item: dict[str, torch.Tensor],
data_dict: dict[torch.Tensor], data_dict: dict[str, torch.Tensor],
delta_timestamps: dict[list[float]], delta_timestamps: dict[str, list[float]],
tol: float = 0.04, tol: float = 0.04,
) -> dict[torch.Tensor]: ) -> dict[torch.Tensor]:
""" """
@ -35,12 +35,12 @@ def load_previous_and_future_frames(
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection. - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
""" """
# get indices of the frames associated to the episode, and their timestamps # get indices of the frames associated to the episode, and their timestamps
ep_data_id_from = item["episode_data_id_from"].item() ep_data_id_from = item["episode_data_index_from"].item()
ep_data_id_to = item["episode_data_id_to"].item() ep_data_id_to = item["episode_data_index_to"].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to + 1, 1) ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps # load timestamps
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from : ep_data_id_to + 1]["timestamp"] ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
# we make the assumption that the timestamps are sorted # we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0] ep_first_ts = ep_timestamps[0]

View File

@ -215,8 +215,8 @@ def eval_policy(
"timestamp": torch.arange(0, num_frames, 1) / fps, "timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames], "next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32), "next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_id_from": torch.tensor([idx_from] * num_frames), "episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames), "episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
} }
for key in observations: for key in observations:
ep_dict[key] = observations[key][ep_id][:num_frames] ep_dict[key] = observations[key][ep_id][:num_frames]

View File

@ -141,15 +141,15 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
online_dataset.data_dict = data_dict online_dataset.data_dict = data_dict
else: else:
# find episode index and data frame indices according to previous episode in online_dataset # find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1 start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
start_index = online_dataset.data_dict["index"][-1].item() + 1 start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example): def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to # note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode example["episode_id"] += start_episode
example["index"] += start_index example["index"] += start_index
example["episode_data_id_from"] += start_index example["episode_data_index_from"] += start_index
example["episode_data_id_to"] += start_index example["episode_data_index_to"] += start_index
return example return example
disable_progress_bar() # map has a tqdm progress bar disable_progress_bar() # map has a tqdm progress bar

View File

@ -77,7 +77,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
# add current frame to list of frames to render # add current frame to list of frames to render
frames[im_key].append(item[im_key]) frames[im_key].append(item[im_key])
end_of_episode = item["index"].item() == item["episode_data_id_to"].item() end_of_episode = item["index"].item() == item["episode_data_index_to"].item()
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys: for im_key in dataset.image_keys:

View File

@ -2,6 +2,9 @@
"citation": "", "citation": "",
"description": "", "description": "",
"features": { "features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": { "observation.state": {
"feature": { "feature": {
"dtype": "float32", "dtype": "float32",
@ -34,17 +37,14 @@
"dtype": "bool", "dtype": "bool",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_from": { "episode_data_index_from": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_to": { "episode_data_index_to": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"observation.images.top": {
"_type": "Image"
},
"index": { "index": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow" "filename": "data-00000-of-00001.arrow"
} }
], ],
"_fingerprint": "05980bca35112ebd", "_fingerprint": "d79cf82ffc86f110",
"_format_columns": null, "_format_columns": null,
"_format_kwargs": {}, "_format_kwargs": {},
"_format_type": "torch", "_format_type": "torch",

View File

@ -2,6 +2,9 @@
"citation": "", "citation": "",
"description": "", "description": "",
"features": { "features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": { "observation.state": {
"feature": { "feature": {
"dtype": "float32", "dtype": "float32",
@ -34,17 +37,14 @@
"dtype": "bool", "dtype": "bool",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_from": { "episode_data_index_from": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_to": { "episode_data_index_to": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"observation.images.top": {
"_type": "Image"
},
"index": { "index": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow" "filename": "data-00000-of-00001.arrow"
} }
], ],
"_fingerprint": "f3330a7e1d8bc55b", "_fingerprint": "d8e4a817b5449498",
"_format_columns": null, "_format_columns": null,
"_format_kwargs": {}, "_format_kwargs": {},
"_format_type": "torch", "_format_type": "torch",

View File

@ -2,6 +2,9 @@
"citation": "", "citation": "",
"description": "", "description": "",
"features": { "features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": { "observation.state": {
"feature": { "feature": {
"dtype": "float32", "dtype": "float32",
@ -34,17 +37,14 @@
"dtype": "bool", "dtype": "bool",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_from": { "episode_data_index_from": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_to": { "episode_data_index_to": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"observation.images.top": {
"_type": "Image"
},
"index": { "index": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow" "filename": "data-00000-of-00001.arrow"
} }
], ],
"_fingerprint": "42aa77ffb6863924", "_fingerprint": "f03482befa767127",
"_format_columns": null, "_format_columns": null,
"_format_kwargs": {}, "_format_kwargs": {},
"_format_type": "torch", "_format_type": "torch",

View File

@ -2,6 +2,9 @@
"citation": "", "citation": "",
"description": "", "description": "",
"features": { "features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": { "observation.state": {
"feature": { "feature": {
"dtype": "float32", "dtype": "float32",
@ -34,17 +37,14 @@
"dtype": "bool", "dtype": "bool",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_from": { "episode_data_index_from": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_to": { "episode_data_index_to": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"observation.images.top": {
"_type": "Image"
},
"index": { "index": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow" "filename": "data-00000-of-00001.arrow"
} }
], ],
"_fingerprint": "43f176a3740fe622", "_fingerprint": "93e03c6320c7d56e",
"_format_columns": null, "_format_columns": null,
"_format_kwargs": {}, "_format_kwargs": {},
"_format_type": "torch", "_format_type": "torch",

View File

@ -45,11 +45,11 @@
"dtype": "bool", "dtype": "bool",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_from": { "episode_data_index_from": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_to": { "episode_data_index_to": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow" "filename": "data-00000-of-00001.arrow"
} }
], ],
"_fingerprint": "f7ed966ae18000ae", "_fingerprint": "21bb9a76ed78a475",
"_format_columns": null, "_format_columns": null,
"_format_kwargs": {}, "_format_kwargs": {},
"_format_type": "torch", "_format_type": "torch",

View File

@ -41,11 +41,11 @@
"dtype": "bool", "dtype": "bool",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_from": { "episode_data_index_from": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },
"episode_data_id_to": { "episode_data_index_to": {
"dtype": "int64", "dtype": "int64",
"_type": "Value" "_type": "Value"
}, },

View File

@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow" "filename": "data-00000-of-00001.arrow"
} }
], ],
"_fingerprint": "7dcd82fc3815bba6", "_fingerprint": "a95cbec45e3bb9d6",
"_format_columns": null, "_format_columns": null,
"_format_kwargs": {}, "_format_kwargs": {},
"_format_type": "torch", "_format_type": "torch",

View File

@ -95,12 +95,14 @@ def test_compute_stats():
""" """
from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.datasets.xarm import XarmDataset
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1] # get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset( dataset = XarmDataset(
dataset_id="xarm_lift_medium", dataset_id="xarm_lift_medium",
root=DATA_DIR,
transform=transform, transform=transform,
) )
@ -115,11 +117,11 @@ def test_compute_stats():
# get all frames from the dataset in the same dtype and range as during compute_stats # get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=16, num_workers=8,
batch_size=len(dataset), batch_size=len(dataset),
shuffle=False, shuffle=False,
) )
data_dict = next(iter(dataloader)) # takes 23 seconds data_dict = next(iter(dataloader))
# compute stats based on all frames from the dataset without any batching # compute stats based on all frames from the dataset without any batching
expected_stats = {} expected_stats = {}
@ -154,8 +156,8 @@ def test_load_previous_and_future_frames_within_tolerance():
data_dict = Dataset.from_dict({ data_dict = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4], "index": [0, 1, 2, 3, 4],
"episode_data_id_from": [0, 0, 0, 0, 0], "episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_id_to": [4, 4, 4, 4, 4], "episode_data_index_to": [4, 4, 4, 4, 4],
}) })
data_dict = data_dict.with_format("torch") data_dict = data_dict.with_format("torch")
item = data_dict[2] item = data_dict[2]
@ -170,8 +172,8 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
data_dict = Dataset.from_dict({ data_dict = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4], "index": [0, 1, 2, 3, 4],
"episode_data_id_from": [0, 0, 0, 0, 0], "episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_id_to": [4, 4, 4, 4, 4], "episode_data_index_to": [4, 4, 4, 4, 4],
}) })
data_dict = data_dict.with_format("torch") data_dict = data_dict.with_format("torch")
item = data_dict[2] item = data_dict[2]
@ -184,8 +186,8 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
data_dict = Dataset.from_dict({ data_dict = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4], "index": [0, 1, 2, 3, 4],
"episode_data_id_from": [0, 0, 0, 0, 0], "episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_id_to": [4, 4, 4, 4, 4], "episode_data_index_to": [4, 4, 4, 4, 4],
}) })
data_dict = data_dict.with_format("torch") data_dict = data_dict.with_format("torch")
item = data_dict[2] item = data_dict[2]