fix unit tests, stats was missing, visualize_dataset was broken
parent 69eeced9d9
commit 4a3eac4743
@@ -53,7 +53,11 @@ def make_dataset(
         stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
     elif stats_path is None:
         # load stats if the file exists already or compute stats and save it
-        precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
+        if DATA_DIR is None:
+            # TODO(rcadene): clean stats
+            precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
+        else:
+            precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
         if precomputed_stats_path.exists():
             stats = torch.load(precomputed_stats_path)
         else:

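For context, the added branch makes the precomputed-stats location depend on whether DATA_DIR is set. A minimal standalone sketch of that resolution logic follows; only DATA_DIR, cfg.dataset_id, and the "stats.pth" filename come from the hunk, while the helper names and the environment-variable fallback are assumptions for illustration.

import os
from pathlib import Path

import torch

# Assumption for this sketch: DATA_DIR is read from an environment variable;
# None means "use the local ./data folder".
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None


def resolve_stats_path(dataset_id: str) -> Path:
    # Mirrors the added if/else: fall back to the local "data" folder when
    # DATA_DIR is unset, otherwise look under DATA_DIR.
    if DATA_DIR is None:
        return Path("data") / dataset_id / "stats.pth"
    return DATA_DIR / dataset_id / "stats.pth"


def load_stats_if_precomputed(dataset_id: str):
    # Load precomputed stats when the file exists; the real code computes and
    # saves them in the else-branch, which is not reproduced here.
    path = resolve_stats_path(dataset_id)
    return torch.load(path) if path.exists() else None
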
@@ -5,6 +5,7 @@ useless dependencies when using datasets.
 
 import io
 import pickle
+import shutil
 from pathlib import Path
 
 import einops

@@ -44,7 +45,7 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return False
 
 
-def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
+def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
     try:
         import pymunk
         from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely

@@ -197,12 +198,12 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
     dataset = dataset.with_format("torch")
 
     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
 
 
-def download_and_upload_xarm(root, dataset_id, fps=15):
+def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
     root = Path(root)
     raw_dir = root / f"{dataset_id}_raw"
     if not raw_dir.exists():

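These hunks save only the first episode as an on-disk test fixture under root_tests and pin the hub upload to a named revision. As a hedged illustration (not part of the commit), a fixture written with root_tests set to "tests/data" could be read back in a test roughly like this; the path is an assumption about what root_tests points to in CI.

from datasets import load_from_disk

# Assumed fixture location: {root_tests}/{dataset_id}/train with root_tests="tests/data".
fixture = load_from_disk("tests/data/pusht/train")
print(len(fixture), fixture.column_names)  # a handful of frames from episode 0
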
@@ -308,12 +309,12 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
     dataset = dataset.with_format("torch")
 
     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
 
 
-def download_and_upload_aloha(root, dataset_id, fps=50):
+def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
     folder_urls = {
         "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
         "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",

@@ -453,16 +454,30 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
     dataset = dataset.with_format("torch")
 
     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
 
 
 if __name__ == "__main__":
     root = "data"
-    download_and_upload_pusht(root, dataset_id="pusht")
-    download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
-    download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human")
-    download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted")
-    download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human")
-    download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_scripted")
+    root_tests = "{root_tests}"
+
+    download_and_upload_pusht(root, root_tests, dataset_id="pusht")
+    download_and_upload_xarm(root, root_tests, dataset_id="xarm_lift_medium")
+    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_human")
+    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_insertion_scripted")
+    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_human")
+    download_and_upload_aloha(root, root_tests, dataset_id="aloha_sim_transfer_cube_scripted")
+
+    dataset_ids = [
+        "pusht",
+        "xarm_lift_medium",
+        "aloha_sim_insertion_human",
+        "aloha_sim_insertion_scripted",
+        "aloha_sim_transfer_cube_human",
+        "aloha_sim_transfer_cube_scripted",
+    ]
+    for dataset_id in dataset_ids:
+        # assume stats have been precomputed
+        shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")

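The literal string "{root_tests}" assigned in the __main__ block reads like an unfilled placeholder. With a concrete destination, the final fixture-copy step would look roughly like the standalone sketch below; "tests/data" and the mkdir call are assumptions, the rest mirrors the hunk.

import shutil
from pathlib import Path

root = Path("data")
root_tests = Path("tests/data")  # assumed concrete value for root_tests

for dataset_id in ["pusht", "xarm_lift_medium", "aloha_sim_insertion_human"]:
    # stats.pth is assumed to have been precomputed under {root}/{dataset_id}/.
    src = root / dataset_id / "stats.pth"
    dst = root_tests / dataset_id / "stats.pth"
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(src, dst)
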
@@ -62,12 +62,12 @@ def render_dataset(dataset, out_dir, max_num_episodes):
     )
     dl_iter = iter(dataloader)
 
-    num_episodes = len(dataset.data_ids_per_episode)
-    for ep_id in range(min(max_num_episodes, num_episodes)):
+    for ep_id in range(min(max_num_episodes, dataset.num_episodes)):
         logging.info(f"Rendering episode {ep_id}")
 
         frames = {}
-        for _ in dataset.data_ids_per_episode[ep_id]:
+        end_of_episode = False
+        while not end_of_episode:
             item = next(dl_iter)
 
             for im_key in dataset.image_keys:

@@ -77,6 +77,8 @@ def render_dataset(dataset, out_dir, max_num_episodes):
                 # add current frame to list of frames to render
                 frames[im_key].append(item[im_key])
 
+            end_of_episode = item["index"].item() == item["episode_data_id_to"].item()
+
         out_dir.mkdir(parents=True, exist_ok=True)
         for im_key in dataset.image_keys:
             if len(dataset.image_keys) > 1:

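The two visualize_dataset hunks swap index-based episode iteration for a sentinel-driven loop: items are pulled from the dataloader until the current item's global index reaches the episode's last index. A toy, self-contained version of that pattern, with plain ints instead of tensors and made-up data:

# Each item carries its global index and the last index of its episode.
items = iter(
    [
        {"index": 0, "episode_data_id_to": 2, "frame": "f0"},
        {"index": 1, "episode_data_id_to": 2, "frame": "f1"},
        {"index": 2, "episode_data_id_to": 2, "frame": "f2"},
    ]
)

frames = []
end_of_episode = False
while not end_of_episode:
    item = next(items)
    frames.append(item["frame"])
    # In the real code both values are tensors, hence the .item() calls there.
    end_of_episode = item["index"] == item["episode_data_id_to"]

print(frames)  # ['f0', 'f1', 'f2']
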
(6 binary files changed; contents not shown)