fix unit tests, stats was missing, visualize_dataset was broken
This commit is contained in:
parent
69eeced9d9
commit
4a3eac4743
|
@ -53,7 +53,11 @@ def make_dataset(
|
||||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
elif stats_path is None:
|
elif stats_path is None:
|
||||||
# load stats if the file exists already or compute stats and save it
|
# load stats if the file exists already or compute stats and save it
|
||||||
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
if DATA_DIR is None:
|
||||||
|
# TODO(rcadene): clean stats
|
||||||
|
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
||||||
|
else:
|
||||||
|
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
|
||||||
if precomputed_stats_path.exists():
|
if precomputed_stats_path.exists():
|
||||||
stats = torch.load(precomputed_stats_path)
|
stats = torch.load(precomputed_stats_path)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -5,6 +5,7 @@ useless dependencies when using datasets.
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import pickle
|
import pickle
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
@ -44,7 +45,7 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
|
def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
|
||||||
try:
|
try:
|
||||||
import pymunk
|
import pymunk
|
||||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||||
|
@ -197,12 +198,12 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
|
||||||
dataset = dataset.with_format("torch")
|
dataset = dataset.with_format("torch")
|
||||||
|
|
||||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||||
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
|
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
|
||||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_xarm(root, dataset_id, fps=15):
|
def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
|
||||||
root = Path(root)
|
root = Path(root)
|
||||||
raw_dir = root / f"{dataset_id}_raw"
|
raw_dir = root / f"{dataset_id}_raw"
|
||||||
if not raw_dir.exists():
|
if not raw_dir.exists():
|
||||||
|
@ -308,12 +309,12 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
|
||||||
dataset = dataset.with_format("torch")
|
dataset = dataset.with_format("torch")
|
||||||
|
|
||||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||||
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
|
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
|
||||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_aloha(root, dataset_id, fps=50):
|
def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
|
||||||
folder_urls = {
|
folder_urls = {
|
||||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||||
|
@ -453,16 +454,30 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
|
||||||
dataset = dataset.with_format("torch")
|
dataset = dataset.with_format("torch")
|
||||||
|
|
||||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||||
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
|
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
|
||||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
root = "data"
|
root = "data"
|
||||||
download_and_upload_pusht(root, dataset_id="pusht")
|
root_tests = "{root_tests}"
|
||||||
download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
|
|
||||||
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human")
|
download_and_upload_pusht(root, root_tests, dataset_id="pusht")
|
||||||
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted")
|
download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
|
||||||
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human")
|
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
|
||||||
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_scripted")
|
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
|
||||||
|
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
|
||||||
|
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
|
||||||
|
|
||||||
|
dataset_ids = [
|
||||||
|
"pusht",
|
||||||
|
"xarm_lift_medium",
|
||||||
|
"aloha_sim_insertion_human",
|
||||||
|
"aloha_sim_insertion_scripted",
|
||||||
|
"aloha_sim_transfer_cube_human",
|
||||||
|
"aloha_sim_transfer_cube_scripted",
|
||||||
|
]
|
||||||
|
for dataset_id in dataset_ids:
|
||||||
|
# assume stats have been precomputed
|
||||||
|
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
|
||||||
|
|
|
@ -62,12 +62,12 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
||||||
)
|
)
|
||||||
dl_iter = iter(dataloader)
|
dl_iter = iter(dataloader)
|
||||||
|
|
||||||
num_episodes = len(dataset.data_ids_per_episode)
|
for ep_id in range(min(max_num_episodes, dataset.num_episodes)):
|
||||||
for ep_id in range(min(max_num_episodes, num_episodes)):
|
|
||||||
logging.info(f"Rendering episode {ep_id}")
|
logging.info(f"Rendering episode {ep_id}")
|
||||||
|
|
||||||
frames = {}
|
frames = {}
|
||||||
for _ in dataset.data_ids_per_episode[ep_id]:
|
end_of_episode = False
|
||||||
|
while not end_of_episode:
|
||||||
item = next(dl_iter)
|
item = next(dl_iter)
|
||||||
|
|
||||||
for im_key in dataset.image_keys:
|
for im_key in dataset.image_keys:
|
||||||
|
@ -77,6 +77,8 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
||||||
# add current frame to list of frames to render
|
# add current frame to list of frames to render
|
||||||
frames[im_key].append(item[im_key])
|
frames[im_key].append(item[im_key])
|
||||||
|
|
||||||
|
end_of_episode = item["index"].item() == item["episode_data_id_to"].item()
|
||||||
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
for im_key in dataset.image_keys:
|
for im_key in dataset.image_keys:
|
||||||
if len(dataset.image_keys) > 1:
|
if len(dataset.image_keys) > 1:
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue