fix unit tests, stats was missing, visualize_dataset was broken

Cadene 2024-04-16 12:53:31 +00:00
parent 69eeced9d9
commit 4a3eac4743
9 changed files with 37 additions and 16 deletions

View File

@@ -53,7 +53,11 @@ def make_dataset(
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load stats if the file exists already or compute stats and save it
if DATA_DIR is None:
# TODO(rcadene): clean stats
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
else:
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
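
The hunk above makes the stats cache path follow DATA_DIR when it is set, falling back to the local data/ folder otherwise. A minimal sketch of the resulting load-or-compute flow (the compute_stats helper and the save step are assumptions for illustration, not part of this diff):

    from pathlib import Path
    import torch

    def load_or_compute_stats(cfg, dataset, data_dir=None):
        # Prefer a stats.pth cached under DATA_DIR; fall back to the local data/ folder.
        base = Path(data_dir) if data_dir is not None else Path("data")
        precomputed_stats_path = base / cfg.dataset_id / "stats.pth"
        if precomputed_stats_path.exists():
            return torch.load(precomputed_stats_path)
        # Otherwise compute the stats once and cache them for the next run.
        stats = compute_stats(dataset)  # hypothetical helper aggregating per-key min/max/mean/std
        precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(stats, precomputed_stats_path)
        return stats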

View File

@@ -5,6 +5,7 @@ useless dependencies when using datasets.
import io
import pickle
import shutil
from pathlib import Path
import einops
@@ -44,7 +45,7 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return False
def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -197,12 +198,12 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
dataset = dataset.with_format("torch")
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
def download_and_upload_xarm(root, dataset_id, fps=15):
def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
if not raw_dir.exists():
@@ -308,12 +309,12 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
dataset = dataset.with_format("torch")
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
def download_and_upload_aloha(root, dataset_id, fps=50):
def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
folder_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
@@ -453,16 +454,30 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
dataset = dataset.with_format("torch")
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
if __name__ == "__main__":
root = "data"
download_and_upload_pusht(root, dataset_id="pusht")
download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human")
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted")
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human")
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_scripted")
root_tests = "{root_tests}"
download_and_upload_pusht(root, root_tests, dataset_id="pusht")
download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
dataset_ids = [
"pusht",
"xarm_lift_medium",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
for dataset_id in dataset_ids:
# assume stats have been precomputed
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
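
All three download_and_upload_* functions now take a root_tests argument so the first-episode fixture lands under the test data directory, hub pushes are pinned to revision "v1.0", and the precomputed stats.pth is copied next to each fixture. A sketch of that shared pattern, assuming the fixture root resolves to the repository's tests/data folder and that dataset is a Hugging Face datasets.Dataset as in the script above:

    import shutil
    from pathlib import Path

    def export_dataset(dataset, ep_dicts, dataset_id, root="data", root_tests="tests/data"):
        # Keep only the first episode as a small on-disk fixture for unit tests.
        num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
        dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
        # Push the full dataset to the hub under a pinned revision tag.
        dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
        # Copy the precomputed stats so tests can load them without recomputing.
        stats_path = Path(root) / dataset_id / "stats.pth"
        if stats_path.exists():
            shutil.copy(stats_path, f"{root_tests}/{dataset_id}/stats.pth")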

View File

@@ -62,12 +62,12 @@ def render_dataset(dataset, out_dir, max_num_episodes):
)
dl_iter = iter(dataloader)
num_episodes = len(dataset.data_ids_per_episode)
for ep_id in range(min(max_num_episodes, num_episodes)):
for ep_id in range(min(max_num_episodes, dataset.num_episodes)):
logging.info(f"Rendering episode {ep_id}")
frames = {}
for _ in dataset.data_ids_per_episode[ep_id]:
end_of_episode = False
while not end_of_episode:
item = next(dl_iter)
for im_key in dataset.image_keys:
@@ -77,6 +77,8 @@ def render_dataset(dataset, out_dir, max_num_episodes):
# add current frame to list of frames to render
frames[im_key].append(item[im_key])
end_of_episode = item["index"].item() == item["episode_data_id_to"].item()
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:
if len(dataset.image_keys) > 1:
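
The render loop no longer iterates dataset.data_ids_per_episode; it keeps pulling items until the item's global index reaches the episode's last data id. A minimal sketch of that loop, using the field names from the diff above (the dataloader is assumed to yield items one at a time in episode order):

    def collect_episode_frames(dl_iter, image_keys):
        # Accumulate frames per camera key until the episode's last item is seen.
        frames = {key: [] for key in image_keys}
        end_of_episode = False
        while not end_of_episode:
            item = next(dl_iter)
            for im_key in image_keys:
                frames[im_key].append(item[im_key])
            # The episode ends when the global index matches episode_data_id_to.
            end_of_episode = item["index"].item() == item["episode_data_id_to"].item()
        return frames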

BIN tests/data/pusht/stats.pth — Normal file, binary content not shown.

Five additional binary files changed; contents not shown.