From 19812ca47075124c4cb46b51b9e1c1b54d0e69d4 Mon Sep 17 00:00:00 2001 From: Remi Date: Sat, 4 May 2024 16:07:14 +0200 Subject: [PATCH] Add dataset visualization with rerun.io (#131) Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- examples/1_load_lerobot_dataset.py | 26 +- .../_video_benchmark/run_video_benchmark.py | 7 +- lerobot/common/datasets/factory.py | 5 - lerobot/common/datasets/lerobot_dataset.py | 6 +- lerobot/scripts/push_dataset_to_hub.py | 3 +- lerobot/scripts/visualize_dataset.py | 307 +++++++++++++----- poetry.lock | 31 +- pyproject.toml | 3 +- tests/scripts/save_dataset_to_safetensors.py | 3 +- tests/test_datasets.py | 6 +- tests/test_examples.py | 2 +- tests/test_visualize_dataset.py | 29 +- 12 files changed, 280 insertions(+), 148 deletions(-) diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py index e7b3c216..f86199c5 100644 --- a/examples/1_load_lerobot_dataset.py +++ b/examples/1_load_lerobot_dataset.py @@ -43,25 +43,27 @@ print(f"average number of frames per episode: {dataset.num_samples / dataset.num print(f"frames per second used during data collection: {dataset.fps=}") print(f"keys to access images from cameras: {dataset.image_keys=}") -# While the LeRobotDataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. -# It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5. -# TODO(rcadene): remove this example of accessing hf_dataset -dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5) +# Access frame indexes associated to first episode +episode_index = 0 +from_idx = dataset.episode_data_index["from"][episode_index].item() +to_idx = dataset.episode_data_index["to"][episode_index].item() -# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames. -frames = [sample["observation.image"] for sample in dataset] +# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset. +# Here we grab all the image frames. +frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)] -# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention, -# to view them, we convert to uint8 range [0,255] +# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. +# To visualize them, we convert to uint8 range [0,255] frames = [(frame * 255).type(torch.uint8) for frame in frames] -# and to channel last (h,w,c) +# and to channel last (h,w,c). frames = [frame.permute((1, 2, 0)).numpy() for frame in frames] -# and finally save them to a mp4 video +# Finally, we save the frames to a mp4 video for visualization. Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True) -imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps) +imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps) -# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions. Our datasets can load previous and future frames for each key/modality, +# For many machine learning applications we need to load the history of past observations or trajectories of future actions. +# Our datasets can load previous and future frames for each key/modality, # using timestamps differences with the current loaded frame. For instance: delta_timestamps = { # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame diff --git a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py index b6e83a0c..85d48fcf 100644 --- a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py +++ b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py @@ -1,5 +1,4 @@ import json -import os import random import shutil import subprocess @@ -41,10 +40,8 @@ def run_video_benchmark( repo_id = cfg["repo_id"] # TODO(rcadene): rewrite with hardcoding of original images and episodes - dataset = LeRobotDataset( - repo_id, - root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None, - ) + dataset = LeRobotDataset(repo_id) + # Get fps fps = dataset.fps diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index c9711ca3..22dd1789 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,14 +1,10 @@ import logging -import os -from pathlib import Path import torch from omegaconf import OmegaConf from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None - def make_dataset( cfg, @@ -31,7 +27,6 @@ def make_dataset( dataset = LeRobotDataset( cfg.dataset_repo_id, split=split, - root=DATA_DIR, delta_timestamps=delta_timestamps, ) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 186c3e48..c8cfbd8e 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import datasets @@ -13,7 +14,8 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos -CODEBASE_VERSION = "v1.2" +DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None +CODEBASE_VERSION = "v1.3" class LeRobotDataset(torch.utils.data.Dataset): @@ -21,7 +23,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self, repo_id: str, version: str | None = CODEBASE_VERSION, - root: Path | None = None, + root: Path | None = DATA_DIR, split: str = "train", transform: callable = None, delta_timestamps: dict[list[float]] | None = None, diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index f5c9c749..ca8c4600 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -110,7 +110,6 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision): repo_id=repo_id, revision=revision, repo_type="dataset", - allow_patterns=["*.json, *.safetensors"], ) @@ -160,7 +159,7 @@ def push_dataset_to_hub( if out_dir.exists(): shutil.rmtree(out_dir) - if tests_out_dir.exists(): + if tests_out_dir.exists() and save_tests_to_disk: shutil.rmtree(tests_out_dir) if not raw_dir.exists(): diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index a5be5e3f..44acd416 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -1,116 +1,245 @@ +""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. + +Note: The last frame of the episode doesnt always correspond to a final state. +That's because our datasets are composed of transition from state to state up to +the antepenultimate state associated to the ultimate action to arrive in the final state. +However, there might not be a transition from a final state to another state. + +Note: This script aims to visualize the data used to train the neural networks. +~What you see is what you get~. When visualizing image modality, it is often expected to observe +lossly compression artifacts since these images have been decoded from compressed mp4 videos to +save disk space. The compression factor applied has been tuned to not affect success rate. + +Examples: + +- Visualize data stored on a local machine: +``` +local$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 +``` + +- Visualize data stored on a distant machine with a local viewer: +``` +distant$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 \ + --save 1 \ + --output-dir path/to/directory + +local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd . +local$ rerun lerobot_pusht_episode_0.rrd +``` + +- Visualize data stored on a distant machine through streaming: +(You need to forward the websocket port to the distant machine, with +`ssh -L 9087:localhost:9087 username@remote-host`) +``` +distant$ python lerobot/scripts/visualize_dataset.py \ + --repo-id lerobot/pusht \ + --episode-index 0 \ + --mode distant \ + --ws-port 9087 + +local$ rerun ws://localhost:9087 +``` + +""" + +import argparse import logging -import threading +import time from pathlib import Path -import einops -import hydra -import imageio +import rerun as rr import torch +import tqdm -from lerobot.common.datasets.factory import make_dataset -from lerobot.common.logger import log_output_dir -from lerobot.common.utils.utils import init_logging - -NUM_EPISODES_TO_RENDER = 50 -MAX_NUM_STEPS = 1000 -FIRST_FRAME = 0 +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -@hydra.main(version_base="1.2", config_name="default", config_path="../configs") -def visualize_dataset_cli(cfg: dict): - visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) +class EpisodeSampler(torch.utils.data.Sampler): + def __init__(self, dataset, episode_index): + from_idx = dataset.episode_data_index["from"][episode_index].item() + to_idx = dataset.episode_data_index["to"][episode_index].item() + self.frame_ids = range(from_idx, to_idx) + + def __iter__(self): + return iter(self.frame_ids) + + def __len__(self): + return len(self.frame_ids) -def cat_and_write_video(video_path, frames, fps): - frames = torch.cat(frames) - - # Expects images in [0, 1]. - frame = frames[0] - if frame.ndim == 4: - raise NotImplementedError("We currently dont support multiple timestamps.") - c, h, w = frame.shape - assert c < h and c < w, f"expect channel first images, but instead {frame.shape}" - - # sanity check that images are float32 in range [0,1] - assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}" - assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}" - assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}" - - # convert to channel last uint8 [0, 255] - frames = einops.rearrange(frames, "b c h w -> b h w c") - frames = (frames * 255).type(torch.uint8) - imageio.mimsave(video_path, frames.numpy(), fps=fps) +def to_hwc_uint8_numpy(chw_float32_torch): + assert chw_float32_torch.dtype == torch.float32 + assert chw_float32_torch.ndim == 3 + c, h, w = chw_float32_torch.shape + assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" + hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() + return hwc_uint8_numpy -def visualize_dataset(cfg: dict, out_dir=None): - if out_dir is None: - raise NotImplementedError() +def visualize_dataset( + repo_id: str, + episode_index: int, + batch_size: int = 32, + num_workers: int = 0, + mode: str = "local", + web_port: int = 9090, + ws_port: int = 9087, + save: bool = False, + output_dir: Path | None = None, +) -> Path | None: + if save: + assert ( + output_dir is not None + ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." - init_logging() - log_output_dir(out_dir) - - logging.info("make_dataset") - dataset = make_dataset(cfg) - - logging.info("Start rendering episodes from offline buffer") - video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER) - for video_path in video_paths: - logging.info(video_path) - return video_paths - - -def render_dataset(dataset, out_dir, max_num_episodes): - out_dir = Path(out_dir) - video_paths = [] - threads = [] + logging.info("Loading dataset") + dataset = LeRobotDataset(repo_id) + logging.info("Loading dataloader") + episode_sampler = EpisodeSampler(dataset, episode_index) dataloader = torch.utils.data.DataLoader( dataset, - num_workers=4, - batch_size=1, - shuffle=False, + num_workers=num_workers, + batch_size=batch_size, + sampler=episode_sampler, ) - dl_iter = iter(dataloader) - for ep_id in range(min(max_num_episodes, dataset.num_episodes)): - logging.info(f"Rendering episode {ep_id}") + logging.info("Starting Rerun") - frames = {} - end_of_episode = False - while not end_of_episode: - item = next(dl_iter) + if mode not in ["local", "distant"]: + raise ValueError(mode) - for im_key in dataset.image_keys: - # when first frame of episode, initialize frames dict - if im_key not in frames: - frames[im_key] = [] - # add current frame to list of frames to render - frames[im_key].append(item[im_key]) + spawn_local_viewer = mode == "local" and not save + rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) + if mode == "distant": + rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port) - end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1 + logging.info("Logging to Rerun") - out_dir.mkdir(parents=True, exist_ok=True) - for im_key in dataset.image_keys: - if len(dataset.image_keys) > 1: - im_name = im_key.replace("observation.images.", "") - video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4" - else: - video_path = out_dir / f"episode_{ep_id}.mp4" - video_paths.append(video_path) + if num_workers > 0: + # TODO(rcadene): fix data workers hanging when `rr.init` is called + logging.warning("If data loader is hanging, try `--num-workers 0`.") - thread = threading.Thread( - target=cat_and_write_video, - args=(str(video_path), frames[im_key], dataset.fps), - ) - thread.start() - threads.append(thread) + for batch in tqdm.tqdm(dataloader, total=len(dataloader)): + # iterate over the batch + for i in range(len(batch["index"])): + rr.set_time_sequence("frame_index", batch["frame_index"][i].item()) + rr.set_time_seconds("timestamp", batch["timestamp"][i].item()) - for thread in threads: - thread.join() + # display each camera image + for key in dataset.image_keys: + # TODO(rcadene): add `.compress()`? is it lossless? + rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) - logging.info("End of visualize_dataset") - return video_paths + # display each dimension of action space (e.g. actuators command) + if "action" in batch: + for dim_idx, val in enumerate(batch["action"][i]): + rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) + + # display each dimension of observed state space (e.g. agent position in joint space) + if "observation.state" in batch: + for dim_idx, val in enumerate(batch["observation.state"][i]): + rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) + + if "next.done" in batch: + rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) + + if "next.reward" in batch: + rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) + + if "next.success" in batch: + rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) + + if mode == "local" and save: + # save .rrd locally + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + repo_id_str = repo_id.replace("/", "_") + rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd" + rr.save(rrd_path) + return rrd_path + + elif mode == "distant": + # stop the process from exiting since it is serving the websocket connection + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("Ctrl-C received. Exiting.") + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).", + ) + parser.add_argument( + "--episode-index", + type=int, + required=True, + help="Episode to visualize.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=32, + help="Batch size loaded by DataLoader.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=0, + help="Number of processes of Dataloader for loading the data.", + ) + parser.add_argument( + "--mode", + type=str, + default="local", + help=( + "Mode of viewing between 'local' or 'distant'. " + "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " + "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + ), + ) + parser.add_argument( + "--web-port", + type=int, + default=9090, + help="Web port for rerun.io when `--mode distant` is set.", + ) + parser.add_argument( + "--ws-port", + type=int, + default=9087, + help="Web socket port for rerun.io when `--mode distant` is set.", + ) + parser.add_argument( + "--save", + type=int, + default=0, + help=( + "Save a .rrd file in the directory provided by `--output-dir`. " + "It also deactivates the spawning of a viewer. ", + "Visualize the data by running `rerun path/to/file.rrd` on your local machine.", + ), + ) + parser.add_argument( + "--output-dir", + type=str, + help="Directory path to write a .rrd file when `--save 1` is set.", + ) + + args = parser.parse_args() + visualize_dataset(**vars(args)) if __name__ == "__main__": - visualize_dataset_cli() + main() diff --git a/poetry.lock b/poetry.lock index f8112958..1121e68c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -3066,6 +3066,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3210,6 +3211,30 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rerun-sdk" +version = "0.15.1" +description = "The Rerun Logging SDK" +optional = false +python-versions = "<3.13,>=3.8" +files = [ + {file = "rerun_sdk-0.15.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:be8f4e55c53bd9734bd0b8e91a9765daeb55e56caddc1bacdb358d12121daaa0"}, + {file = "rerun_sdk-0.15.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e039ed441b6dcd5939e20f0f67fef4ffd54645777574822f48cd6f636efa3756"}, + {file = "rerun_sdk-0.15.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:5c067ba1c3304a0bb74bd33df8f7145ce7d405c823bfc8709396bbdd672a759e"}, + {file = "rerun_sdk-0.15.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:e8a96fff6e0c184a59b433430f5f87c96895e4b69dc0e43abb56a0e0737edc35"}, + {file = "rerun_sdk-0.15.1-cp38-abi3-win_amd64.whl", hash = "sha256:377a888e0cbe06835f376cd160ab322e9935ebd1317384381856236bd4347950"}, +] + +[package.dependencies] +attrs = ">=23.1.0" +numpy = ">=1.23,<2" +pillow = "*" +pyarrow = ">=14.0.2" +typing-extensions = ">=4.5" + +[package.extras] +tests = ["pytest (==7.1.2)"] + [[package]] name = "robomimic" version = "0.2.0" @@ -4331,5 +4356,5 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" -python-versions = "^3.10" -content-hash = "ba2d6275ad42f34f83193e8c64ef9dca301c6632c05523a564601d322ce7a31d" +python-versions = ">=3.10,<3.13" +content-hash = "d2066576dc4aebaf623c295fe626bf6805fd2ec26a6ba47fa5415204994aa922" diff --git a/pyproject.toml b/pyproject.toml index 23116a7a..32428845 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ packages = [{include = "lerobot"}] [tool.poetry.dependencies] -python = "^3.10" +python = ">=3.10,<3.13" termcolor = "^2.4.0" omegaconf = "^2.3.0" wandb = "^0.16.3" @@ -58,6 +58,7 @@ datasets = "^2.19.0" imagecodecs = { version = "^2024.1.1", optional = true } pyav = "^12.0.5" moviepy = "^1.0.3" +rerun-sdk = "^0.15.1" [tool.poetry.extras] diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index b4b0f76a..17cf2b38 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -11,7 +11,6 @@ Example usage: `python tests/scripts/save_dataset_to_safetensors.py` """ -import os import shutil from pathlib import Path @@ -29,7 +28,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): repo_dir.mkdir(parents=True, exist_ok=True) dataset = LeRobotDataset( - repo_id=repo_id, root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None + repo_id=repo_id, ) # save 2 first frames of first episode diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e4be423f..e50d4108 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,6 +1,5 @@ import json import logging -import os from copy import deepcopy from pathlib import Path @@ -97,9 +96,7 @@ def test_compute_stats_on_xarm(): We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do because we are working with a small dataset). """ - dataset = LeRobotDataset( - "lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None - ) + dataset = LeRobotDataset("lerobot/xarm_lift_medium") # reduce size of dataset sample on which stats compute is tested to 10 frames dataset.hf_dataset = dataset.hf_dataset.select(range(10)) @@ -254,7 +251,6 @@ def test_backward_compatibility(repo_id): dataset = LeRobotDataset( repo_id, - root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None, ) test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id diff --git a/tests/test_examples.py b/tests/test_examples.py index 1fca45f5..8ff808f3 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -18,7 +18,7 @@ def _run_script(path): def test_example_1(): path = "examples/1_load_lerobot_dataset.py" _run_script(path) - assert Path("outputs/examples/1_load_lerobot_dataset/episode_5.mp4").exists() + assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists() def test_examples_3_and_2(): diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 895ff9d9..0124afd3 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -1,31 +1,18 @@ import pytest -from lerobot.common.utils.utils import init_hydra_config from lerobot.scripts.visualize_dataset import visualize_dataset -from .utils import DEFAULT_CONFIG_PATH - @pytest.mark.parametrize( "repo_id", - [ - "lerobot/aloha_sim_insertion_human", - ], + ["lerobot/pusht"], ) def test_visualize_dataset(tmpdir, repo_id): - # TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset - # doesnt support multiple timesteps which requires delta_timestamps to None for images. - cfg = init_hydra_config( - DEFAULT_CONFIG_PATH, - overrides=[ - "policy=act", - "env=aloha", - f"dataset_repo_id={repo_id}", - ], + rrd_path = visualize_dataset( + repo_id, + episode_index=0, + batch_size=32, + save=True, + output_dir=tmpdir, ) - video_paths = visualize_dataset(cfg, out_dir=tmpdir) - - assert len(video_paths) > 0 - - for video_path in video_paths: - assert video_path.exists() + assert rrd_path.exists()