Address comments

Cadene 2024-04-16 17:14:40 +00:00
parent b241ea46dd
commit 36d9e885ef
24 changed files with 100 additions and 94 deletions

View File

@@ -17,6 +17,17 @@ from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
+ def download_and_upload(root, root_tests, dataset_id):
+ if "pusht" in dataset_id:
+ download_and_upload_pusht(root, root_tests, dataset_id)
+ elif "xarm" in dataset_id:
+ download_and_upload_xarm(root, root_tests, dataset_id)
+ elif "aloha" in dataset_id:
+ download_and_upload_aloha(root, root_tests, dataset_id)
+ else:
+ raise ValueError(dataset_id)
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
import zipfile
@@ -87,7 +98,6 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
- data_ids_per_episode = {}
ep_dicts = []
id_from = 0
@@ -150,15 +160,11 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
"episode_data_id_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
ep_dicts.append(ep_dict)
- assert isinstance(episode_id, int)
- data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
- assert len(data_ids_per_episode[episode_id]) == num_frames
id_from += num_frames
data_dict = {}
@@ -190,8 +196,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"next.done": Value(dtype="bool", id=None),
"next.success": Value(dtype="bool", id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
@@ -265,8 +271,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
"episode_data_id_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
ep_dicts.append(ep_dict)
@@ -301,8 +307,8 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
@@ -390,20 +396,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
- ep_dict = {
- "observation.state": state,
- "action": action,
- "episode_id": torch.tensor([ep_id] * num_frames),
- "frame_id": torch.arange(0, num_frames, 1),
- "timestamp": torch.arange(0, num_frames, 1) / fps,
- # "next.observation.state": state,
- # TODO(rcadene): compute reward and success
- # "next.reward": reward,
- "next.done": done,
- # "next.success": success,
- "episode_data_id_from": torch.tensor([id_from] * num_frames),
- "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
- }
+ ep_dict = {}
for cam in cameras[dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
@@ -411,6 +404,23 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
# ep_dict[f"next.observation.images.{cam}"] = image
+ ep_dict.update(
+ {
+ "observation.state": state,
+ "action": action,
+ "episode_id": torch.tensor([ep_id] * num_frames),
+ "frame_id": torch.arange(0, num_frames, 1),
+ "timestamp": torch.arange(0, num_frames, 1) / fps,
+ # "next.observation.state": state,
+ # TODO(rcadene): compute reward and success
+ # "next.reward": reward,
+ "next.done": done,
+ # "next.success": success,
+ "episode_data_index_from": torch.tensor([id_from] * num_frames),
+ "episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
+ }
+ )
+ assert isinstance(ep_id, int)
ep_dicts.append(ep_dict)
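
Because the camera frames are now written into `ep_dict` before `ep_dict.update(...)` adds the scalar fields, the inferred feature order changes, which is presumably why `observation.images.top` moves to the top of the features map in the fixture diffs further down. A toy illustration of the insertion-order sensitivity:

    d = {}
    d["observation.images.top"] = ["img"]  # images inserted first
    d.update({"observation.state": [0.0], "action": [0.0]})
    print(list(d))  # ['observation.images.top', 'observation.state', 'action']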
@@ -446,8 +456,8 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
@@ -461,23 +471,17 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
if __name__ == "__main__":
root = "data"
root_tests = "{root_tests}"
download_and_upload_pusht(root, root_tests, dataset_id="pusht")
download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
root_tests = "tests/data"
dataset_ids = [
"pusht",
"xarm_lift_medium",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
# "pusht",
# "xarm_lift_medium",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
for dataset_id in dataset_ids:
download_and_upload(root, root_tests, dataset_id)
# assume stats have been precomputed
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
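
Beyond the rename, the substantive change in this script is the bound convention: `episode_data_index_to` now points one past an episode's last frame (a half-open interval), where `episode_data_id_to` pointed at the last frame itself. A minimal sketch of the two conventions, with hypothetical values:

    import torch

    id_from, num_frames = 100, 25  # hypothetical placement of one episode

    # old, inclusive convention: _to is the index of the last frame
    episode_data_id_to = id_from + num_frames - 1
    old_ids = torch.arange(id_from, episode_data_id_to + 1, 1)

    # new, half-open convention: _to is one past the last frame
    episode_data_index_to = id_from + num_frames
    new_ids = torch.arange(id_from, episode_data_index_to, 1)

    assert torch.equal(old_ids, new_ids)  # same frames, no +1 needed downstream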

View File

@@ -7,9 +7,9 @@ import tqdm
def load_previous_and_future_frames(
- item: dict[torch.Tensor],
- data_dict: dict[torch.Tensor],
- delta_timestamps: dict[list[float]],
+ item: dict[str, torch.Tensor],
+ data_dict: dict[str, torch.Tensor],
+ delta_timestamps: dict[str, list[float]],
tol: float = 0.04,
) -> dict[torch.Tensor]:
"""
@@ -35,12 +35,12 @@ def load_previous_and_future_frames(
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
- ep_data_id_from = item["episode_data_id_from"].item()
- ep_data_id_to = item["episode_data_id_to"].item()
- ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to + 1, 1)
+ ep_data_id_from = item["episode_data_index_from"].item()
+ ep_data_id_to = item["episode_data_index_to"].item()
+ ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
- ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from : ep_data_id_to + 1]["timestamp"]
+ ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
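
With the half-open bound, both the `torch.arange` call and the `select_columns` slice drop their `+ 1`, matching Python slicing semantics. A self-contained sketch of the access pattern (toy values, only the columns used here):

    import torch
    from datasets import Dataset

    # toy stand-in for data_dict: one 5-frame episode sampled at 10 fps
    data_dict = Dataset.from_dict({
        "timestamp": [0.0, 0.1, 0.2, 0.3, 0.4],
        "episode_data_index_from": [0] * 5,
        "episode_data_index_to": [5] * 5,  # one past the last frame
    }).with_format("torch")

    item = data_dict[2]
    ep_from = item["episode_data_index_from"].item()
    ep_to = item["episode_data_index_to"].item()

    ep_data_ids = torch.arange(ep_from, ep_to, 1)
    ep_timestamps = data_dict.select_columns("timestamp")[ep_from:ep_to]["timestamp"]
    assert len(ep_data_ids) == len(ep_timestamps) == 5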

View File

@@ -215,8 +215,8 @@ def eval_policy(
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_id_from": torch.tensor([idx_from] * num_frames),
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id][:num_frames]
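
In the eval rollout, `idx_from` presumably advances by `num_frames` after each episode so that consecutive episodes tile the global index range without gaps. A sketch of that bookkeeping under this assumption (names shortened, lengths hypothetical):

    import torch

    episode_lengths = [30, 42, 17]  # hypothetical rollout lengths
    idx_from = 0
    for num_frames in episode_lengths:
        index_from = torch.tensor([idx_from] * num_frames)
        index_to = torch.tensor([idx_from + num_frames] * num_frames)  # exclusive
        assert (index_to[0] - index_from[0]).item() == num_frames
        idx_from += num_frames  # next episode starts where this one ends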

View File

@@ -141,15 +141,15 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
online_dataset.data_dict = data_dict
else:
# find episode index and data frame indices according to previous episode in online_dataset
- start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1
- start_index = online_dataset.data_dict["index"][-1].item() + 1
+ start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
+ start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode
example["index"] += start_index
example["episode_data_id_from"] += start_index
example["episode_data_id_to"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index
return example
disable_progress_bar() # map has a tqdm progress bar
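
`shift_indices` is presumably applied to the incoming frames with `Dataset.map` before they are concatenated onto the online dataset; a sketch of that pattern (contents hypothetical):

    from datasets import Dataset

    new_frames = Dataset.from_dict({
        "episode_id": [0, 0, 1],
        "index": [0, 1, 2],
        "episode_data_index_from": [0, 0, 2],
        "episode_data_index_to": [2, 2, 3],
    })
    start_episode = 7  # hypothetical: last episode_id in the online dataset + 1
    start_index = 100  # hypothetical: last global index in the online dataset + 1

    def shift_indices(example):
        example["episode_id"] += start_episode
        example["index"] += start_index
        example["episode_data_index_from"] += start_index
        example["episode_data_index_to"] += start_index
        return example

    new_frames = new_frames.map(shift_indices)  # every row shifted past existing data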

View File

@@ -77,7 +77,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
# add current frame to list of frames to render
frames[im_key].append(item[im_key])
- end_of_episode = item["index"].item() == item["episode_data_id_to"].item()
+ end_of_episode = item["index"].item() == item["episode_data_index_to"].item()
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:
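
One consequence of the half-open convention: the last frame of an episode now satisfies `index == episode_data_index_to - 1`, so an equality test against `episode_data_index_to` itself only fires under the old inclusive bound. A minimal sketch of an end-of-episode check consistent with the new convention:

    import torch

    item = {"index": torch.tensor(4), "episode_data_index_to": torch.tensor(5)}  # toy frame
    end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
    assert end_of_episode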

View File

@@ -2,6 +2,9 @@
"citation": "",
"description": "",
"features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
@@ -34,17 +37,14 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "05980bca35112ebd",
"_fingerprint": "d79cf82ffc86f110",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -2,6 +2,9 @@
"citation": "",
"description": "",
"features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
@@ -34,17 +37,14 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "f3330a7e1d8bc55b",
"_fingerprint": "d8e4a817b5449498",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -2,6 +2,9 @@
"citation": "",
"description": "",
"features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
@@ -34,17 +37,14 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "42aa77ffb6863924",
"_fingerprint": "f03482befa767127",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -2,6 +2,9 @@
"citation": "",
"description": "",
"features": {
"observation.images.top": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
@@ -34,17 +37,14 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "43f176a3740fe622",
"_fingerprint": "93e03c6320c7d56e",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -45,11 +45,11 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "f7ed966ae18000ae",
"_fingerprint": "21bb9a76ed78a475",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -41,11 +41,11 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "7dcd82fc3815bba6",
"_fingerprint": "a95cbec45e3bb9d6",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -95,12 +95,14 @@ def test_compute_stats():
"""
from lerobot.common.datasets.xarm import XarmDataset
+ DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
+ root=DATA_DIR,
transform=transform,
)
@@ -115,11 +117,11 @@ def test_compute_stats():
# get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader(
dataset,
- num_workers=16,
+ num_workers=8,
batch_size=len(dataset),
shuffle=False,
)
- data_dict = next(iter(dataloader)) # takes 23 seconds
+ data_dict = next(iter(dataloader))
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
@@ -154,8 +156,8 @@ def test_load_previous_and_future_frames_within_tolerance():
data_dict = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_id_from": [0, 0, 0, 0, 0],
"episode_data_id_to": [4, 4, 4, 4, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [4, 4, 4, 4, 4],
})
data_dict = data_dict.with_format("torch")
item = data_dict[2]
@@ -170,8 +172,8 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
data_dict = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_id_from": [0, 0, 0, 0, 0],
"episode_data_id_to": [4, 4, 4, 4, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [4, 4, 4, 4, 4],
})
data_dict = data_dict.with_format("torch")
item = data_dict[2]
@@ -184,8 +186,8 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
data_dict = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_id_from": [0, 0, 0, 0, 0],
"episode_data_id_to": [4, 4, 4, 4, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [4, 4, 4, 4, 4],
})
data_dict = data_dict.with_format("torch")
item = data_dict[2]
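
For reference, a sketch of how these fixtures drive the loader; the import path and the `delta_timestamps` value are assumptions, not taken from this diff:

    import torch
    from datasets import Dataset
    from lerobot.common.datasets.utils import load_previous_and_future_frames  # assumed path

    data_dict = Dataset.from_dict({
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
        "index": [0, 1, 2, 3, 4],
        "episode_data_index_from": [0, 0, 0, 0, 0],
        "episode_data_index_to": [4, 4, 4, 4, 4],
    }).with_format("torch")

    item = data_dict[2]
    # request the frame itself plus neighbors 0.1 s before and after
    item = load_previous_and_future_frames(
        item, data_dict, delta_timestamps={"index": [-0.1, 0, 0.1]}, tol=0.04
    )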