Address comments
This commit is contained in:
parent
b241ea46dd
commit
36d9e885ef
|
@ -17,6 +17,17 @@ from datasets import Dataset, Features, Image, Sequence, Value
|
|||
from PIL import Image as PILImage
|
||||
|
||||
|
||||
def download_and_upload(root, root_tests, dataset_id):
|
||||
if "pusht" in dataset_id:
|
||||
download_and_upload_pusht(root, root_tests, dataset_id)
|
||||
elif "xarm" in dataset_id:
|
||||
download_and_upload_xarm(root, root_tests, dataset_id)
|
||||
elif "aloha" in dataset_id:
|
||||
download_and_upload_aloha(root, root_tests, dataset_id)
|
||||
else:
|
||||
raise ValueError(dataset_id)
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
import zipfile
|
||||
|
||||
|
@ -87,7 +98,6 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
|||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
id_from = 0
|
||||
|
@ -150,15 +160,11 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
|||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
|
||||
assert len(data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
|
@ -190,8 +196,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
|||
"next.done": Value(dtype="bool", id=None),
|
||||
"next.success": Value(dtype="bool", id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
|
@ -265,8 +271,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
|||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
|
@ -301,8 +307,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
|||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
|
@ -390,7 +396,16 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
|||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_dict = {
|
||||
ep_dict = {}
|
||||
|
||||
for cam in cameras[dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
|
||||
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image
|
||||
|
||||
ep_dict.update(
|
||||
{
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
|
@ -401,15 +416,10 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
|||
# "next.reward": reward,
|
||||
"next.done": done,
|
||||
# "next.success": success,
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
}
|
||||
|
||||
for cam in cameras[dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
|
||||
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image
|
||||
)
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
ep_dicts.append(ep_dict)
|
||||
|
@ -446,8 +456,8 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
|||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
|
@ -461,23 +471,17 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
|||
|
||||
if __name__ == "__main__":
|
||||
root = "data"
|
||||
root_tests = "{root_tests}"
|
||||
|
||||
download_and_upload_pusht(root, root_tests, dataset_id="pusht")
|
||||
download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
|
||||
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
|
||||
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
|
||||
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
|
||||
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
|
||||
root_tests = "tests/data"
|
||||
|
||||
dataset_ids = [
|
||||
"pusht",
|
||||
"xarm_lift_medium",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
# "pusht",
|
||||
# "xarm_lift_medium",
|
||||
# "aloha_sim_insertion_human",
|
||||
# "aloha_sim_insertion_scripted",
|
||||
# "aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
for dataset_id in dataset_ids:
|
||||
download_and_upload(root, root_tests, dataset_id)
|
||||
# assume stats have been precomputed
|
||||
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
|
||||
|
|
|
@ -7,9 +7,9 @@ import tqdm
|
|||
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[torch.Tensor],
|
||||
data_dict: dict[torch.Tensor],
|
||||
delta_timestamps: dict[list[float]],
|
||||
item: dict[str, torch.Tensor],
|
||||
data_dict: dict[str, torch.Tensor],
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tol: float = 0.04,
|
||||
) -> dict[torch.Tensor]:
|
||||
"""
|
||||
|
@ -35,12 +35,12 @@ def load_previous_and_future_frames(
|
|||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_data_id_from = item["episode_data_id_from"].item()
|
||||
ep_data_id_to = item["episode_data_id_to"].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to + 1, 1)
|
||||
ep_data_id_from = item["episode_data_index_from"].item()
|
||||
ep_data_id_to = item["episode_data_index_to"].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from : ep_data_id_to + 1]["timestamp"]
|
||||
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
|
|
|
@ -215,8 +215,8 @@ def eval_policy(
|
|||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": dones[ep_id, :num_frames],
|
||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||
"episode_data_id_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
|
||||
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||
|
|
|
@ -141,15 +141,15 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
|
|||
online_dataset.data_dict = data_dict
|
||||
else:
|
||||
# find episode index and data frame indices according to previous episode in online_dataset
|
||||
start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1
|
||||
start_index = online_dataset.data_dict["index"][-1].item() + 1
|
||||
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
|
||||
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
|
||||
|
||||
def shift_indices(example):
|
||||
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
|
||||
example["episode_id"] += start_episode
|
||||
example["index"] += start_index
|
||||
example["episode_data_id_from"] += start_index
|
||||
example["episode_data_id_to"] += start_index
|
||||
example["episode_data_index_from"] += start_index
|
||||
example["episode_data_index_to"] += start_index
|
||||
return example
|
||||
|
||||
disable_progress_bar() # map has a tqdm progress bar
|
||||
|
|
|
@ -77,7 +77,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
|||
# add current frame to list of frames to render
|
||||
frames[im_key].append(item[im_key])
|
||||
|
||||
end_of_episode = item["index"].item() == item["episode_data_id_to"].item()
|
||||
end_of_episode = item["index"].item() == item["episode_data_index_to"].item()
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for im_key in dataset.image_keys:
|
||||
|
|
Binary file not shown.
|
@ -2,6 +2,9 @@
|
|||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
|
@ -34,17 +37,14 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "05980bca35112ebd",
|
||||
"_fingerprint": "d79cf82ffc86f110",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -2,6 +2,9 @@
|
|||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
|
@ -34,17 +37,14 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "f3330a7e1d8bc55b",
|
||||
"_fingerprint": "d8e4a817b5449498",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -2,6 +2,9 @@
|
|||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
|
@ -34,17 +37,14 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "42aa77ffb6863924",
|
||||
"_fingerprint": "f03482befa767127",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -2,6 +2,9 @@
|
|||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
|
@ -34,17 +37,14 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "43f176a3740fe622",
|
||||
"_fingerprint": "93e03c6320c7d56e",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -45,11 +45,11 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "f7ed966ae18000ae",
|
||||
"_fingerprint": "21bb9a76ed78a475",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
Binary file not shown.
|
@ -41,11 +41,11 @@
|
|||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"episode_data_index_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"episode_data_index_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "7dcd82fc3815bba6",
|
||||
"_fingerprint": "a95cbec45e3bb9d6",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
|
|
|
@ -95,12 +95,14 @@ def test_compute_stats():
|
|||
"""
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
# get transform to convert images from uint8 [0,255] to float32 [0,1]
|
||||
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
|
||||
|
||||
dataset = XarmDataset(
|
||||
dataset_id="xarm_lift_medium",
|
||||
root=DATA_DIR,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
|
@ -115,11 +117,11 @@ def test_compute_stats():
|
|||
# get all frames from the dataset in the same dtype and range as during compute_stats
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=16,
|
||||
num_workers=8,
|
||||
batch_size=len(dataset),
|
||||
shuffle=False,
|
||||
)
|
||||
data_dict = next(iter(dataloader)) # takes 23 seconds
|
||||
data_dict = next(iter(dataloader))
|
||||
|
||||
# compute stats based on all frames from the dataset without any batching
|
||||
expected_stats = {}
|
||||
|
@ -154,8 +156,8 @@ def test_load_previous_and_future_frames_within_tolerance():
|
|||
data_dict = Dataset.from_dict({
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_data_id_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_id_to": [4, 4, 4, 4, 4],
|
||||
"episode_data_index_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_index_to": [4, 4, 4, 4, 4],
|
||||
})
|
||||
data_dict = data_dict.with_format("torch")
|
||||
item = data_dict[2]
|
||||
|
@ -170,8 +172,8 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
|
|||
data_dict = Dataset.from_dict({
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_data_id_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_id_to": [4, 4, 4, 4, 4],
|
||||
"episode_data_index_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_index_to": [4, 4, 4, 4, 4],
|
||||
})
|
||||
data_dict = data_dict.with_format("torch")
|
||||
item = data_dict[2]
|
||||
|
@ -184,8 +186,8 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
|
|||
data_dict = Dataset.from_dict({
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_data_id_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_id_to": [4, 4, 4, 4, 4],
|
||||
"episode_data_index_from": [0, 0, 0, 0, 0],
|
||||
"episode_data_index_to": [4, 4, 4, 4, 4],
|
||||
})
|
||||
data_dict = data_dict.with_format("torch")
|
||||
item = data_dict[2]
|
||||
|
|
Loading…
Reference in New Issue