Add dataset visualization with rerun.io (#131)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
parent
c015252e20
commit
19812ca470
|
@ -43,25 +43,27 @@ print(f"average number of frames per episode: {dataset.num_samples / dataset.num
|
|||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
|
||||
# While the LeRobotDataset adds helpers for working within our library, we still expose the underling Hugging Face dataset.
|
||||
# It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
# TODO(rcadene): remove this example of accessing hf_dataset
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
# Access frame indexes associated to first episode
|
||||
episode_index = 0
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
|
||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [sample["observation.image"] for sample in dataset]
|
||||
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
|
||||
# Here we grab all the image frames.
|
||||
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
|
||||
# to view them, we convert to uint8 range [0,255]
|
||||
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
|
||||
# To visualize them, we convert to uint8 range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c)
|
||||
# and to channel last (h,w,c).
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
|
||||
# and finally save them to a mp4 video
|
||||
# Finally, we save the frames to a mp4 video for visualization.
|
||||
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
||||
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
|
||||
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
|
||||
|
||||
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions. Our datasets can load previous and future frames for each key/modality,
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
|
||||
# Our datasets can load previous and future frames for each key/modality,
|
||||
# using timestamps differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
|
@ -41,10 +40,8 @@ def run_video_benchmark(
|
|||
repo_id = cfg["repo_id"]
|
||||
|
||||
# TODO(rcadene): rewrite with hardcoding of original images and episodes
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||
)
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# Get fps
|
||||
fps = dataset.fps
|
||||
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
|
@ -31,7 +27,6 @@ def make_dataset(
|
|||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
|
@ -13,7 +14,8 @@ from lerobot.common.datasets.utils import (
|
|||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||
|
||||
CODEBASE_VERSION = "v1.2"
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
CODEBASE_VERSION = "v1.3"
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
@ -21,7 +23,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self,
|
||||
repo_id: str,
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = None,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
|
|
|
@ -110,7 +110,6 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
|
|||
repo_id=repo_id,
|
||||
revision=revision,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["*.json, *.safetensors"],
|
||||
)
|
||||
|
||||
|
||||
|
@ -160,7 +159,7 @@ def push_dataset_to_hub(
|
|||
if out_dir.exists():
|
||||
shutil.rmtree(out_dir)
|
||||
|
||||
if tests_out_dir.exists():
|
||||
if tests_out_dir.exists() and save_tests_to_disk:
|
||||
shutil.rmtree(tests_out_dir)
|
||||
|
||||
if not raw_dir.exists():
|
||||
|
|
|
@ -1,116 +1,245 @@
|
|||
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
||||
|
||||
Note: The last frame of the episode doesnt always correspond to a final state.
|
||||
That's because our datasets are composed of transition from state to state up to
|
||||
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
||||
However, there might not be a transition from a final state to another state.
|
||||
|
||||
Note: This script aims to visualize the data used to train the neural networks.
|
||||
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
||||
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
|
||||
save disk space. The compression factor applied has been tuned to not affect success rate.
|
||||
|
||||
Examples:
|
||||
|
||||
- Visualize data stored on a local machine:
|
||||
```
|
||||
local$ python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine with a local viewer:
|
||||
```
|
||||
distant$ python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--save 1 \
|
||||
--output-dir path/to/directory
|
||||
|
||||
local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
|
||||
local$ rerun lerobot_pusht_episode_0.rrd
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine through streaming:
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
--ws-port 9087
|
||||
|
||||
local$ rerun ws://localhost:9087
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import imageio
|
||||
import rerun as rr
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
NUM_EPISODES_TO_RENDER = 50
|
||||
MAX_NUM_STEPS = 1000
|
||||
FIRST_FRAME = 0
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def visualize_dataset_cli(cfg: dict):
|
||||
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, episode_index):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
frames = torch.cat(frames)
|
||||
|
||||
# Expects images in [0, 1].
|
||||
frame = frames[0]
|
||||
if frame.ndim == 4:
|
||||
raise NotImplementedError("We currently dont support multiple timestamps.")
|
||||
c, h, w = frame.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
|
||||
assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
|
||||
assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}"
|
||||
|
||||
# convert to channel last uint8 [0, 255]
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||
frames = (frames * 255).type(torch.uint8)
|
||||
imageio.mimsave(video_path, frames.numpy(), fps=fps)
|
||||
def to_hwc_uint8_numpy(chw_float32_torch):
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
c, h, w = chw_float32_torch.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
|
||||
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
|
||||
return hwc_uint8_numpy
|
||||
|
||||
|
||||
def visualize_dataset(cfg: dict, out_dir=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
def visualize_dataset(
|
||||
repo_id: str,
|
||||
episode_index: int,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 0,
|
||||
mode: str = "local",
|
||||
web_port: int = 9090,
|
||||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert (
|
||||
output_dir is not None
|
||||
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
|
||||
|
||||
init_logging()
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("make_dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
|
||||
for video_path in video_paths:
|
||||
logging.info(video_path)
|
||||
return video_paths
|
||||
|
||||
|
||||
def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
out_dir = Path(out_dir)
|
||||
video_paths = []
|
||||
threads = []
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
sampler=episode_sampler,
|
||||
)
|
||||
dl_iter = iter(dataloader)
|
||||
|
||||
for ep_id in range(min(max_num_episodes, dataset.num_episodes)):
|
||||
logging.info(f"Rendering episode {ep_id}")
|
||||
logging.info("Starting Rerun")
|
||||
|
||||
frames = {}
|
||||
end_of_episode = False
|
||||
while not end_of_episode:
|
||||
item = next(dl_iter)
|
||||
if mode not in ["local", "distant"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
for im_key in dataset.image_keys:
|
||||
# when first frame of episode, initialize frames dict
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
# add current frame to list of frames to render
|
||||
frames[im_key].append(item[im_key])
|
||||
spawn_local_viewer = mode == "local" and not save
|
||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
|
||||
if mode == "distant":
|
||||
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
|
||||
|
||||
end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for im_key in dataset.image_keys:
|
||||
if len(dataset.image_keys) > 1:
|
||||
im_name = im_key.replace("observation.images.", "")
|
||||
video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
|
||||
else:
|
||||
video_path = out_dir / f"episode_{ep_id}.mp4"
|
||||
video_paths.append(video_path)
|
||||
if num_workers > 0:
|
||||
# TODO(rcadene): fix data workers hanging when `rr.init` is called
|
||||
logging.warning("If data loader is hanging, try `--num-workers 0`.")
|
||||
|
||||
thread = threading.Thread(
|
||||
target=cat_and_write_video,
|
||||
args=(str(video_path), frames[im_key], dataset.fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
# iterate over the batch
|
||||
for i in range(len(batch["index"])):
|
||||
rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
|
||||
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# display each camera image
|
||||
for key in dataset.image_keys:
|
||||
# TODO(rcadene): add `.compress()`? is it lossless?
|
||||
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
|
||||
|
||||
logging.info("End of visualize_dataset")
|
||||
return video_paths
|
||||
# display each dimension of action space (e.g. actuators command)
|
||||
if "action" in batch:
|
||||
for dim_idx, val in enumerate(batch["action"][i]):
|
||||
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
|
||||
|
||||
# display each dimension of observed state space (e.g. agent position in joint space)
|
||||
if "observation.state" in batch:
|
||||
for dim_idx, val in enumerate(batch["observation.state"][i]):
|
||||
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
|
||||
|
||||
if "next.done" in batch:
|
||||
rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
|
||||
|
||||
if "next.reward" in batch:
|
||||
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
|
||||
|
||||
if "next.success" in batch:
|
||||
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
|
||||
|
||||
if mode == "local" and save:
|
||||
# save .rrd locally
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
repo_id_str = repo_id.replace("/", "_")
|
||||
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
|
||||
rr.save(rrd_path)
|
||||
return rrd_path
|
||||
|
||||
elif mode == "distant":
|
||||
# stop the process from exiting since it is serving the websocket connection
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("Ctrl-C received. Exiting.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episode-index",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Episode to visualize.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size loaded by DataLoader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of processes of Dataloader for loading the data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="local",
|
||||
help=(
|
||||
"Mode of viewing between 'local' or 'distant'. "
|
||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--web-port",
|
||||
type=int,
|
||||
default=9090,
|
||||
help="Web port for rerun.io when `--mode distant` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ws-port",
|
||||
type=int,
|
||||
default=9087,
|
||||
help="Web socket port for rerun.io when `--mode distant` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Save a .rrd file in the directory provided by `--output-dir`. "
|
||||
"It also deactivates the spawning of a viewer. ",
|
||||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
help="Directory path to write a .rrd file when `--save 1` is set.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
visualize_dataset_cli()
|
||||
main()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
|
@ -3066,6 +3066,7 @@ files = [
|
|||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
|
@ -3210,6 +3211,30 @@ urllib3 = ">=1.21.1,<3"
|
|||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "rerun-sdk"
|
||||
version = "0.15.1"
|
||||
description = "The Rerun Logging SDK"
|
||||
optional = false
|
||||
python-versions = "<3.13,>=3.8"
|
||||
files = [
|
||||
{file = "rerun_sdk-0.15.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:be8f4e55c53bd9734bd0b8e91a9765daeb55e56caddc1bacdb358d12121daaa0"},
|
||||
{file = "rerun_sdk-0.15.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e039ed441b6dcd5939e20f0f67fef4ffd54645777574822f48cd6f636efa3756"},
|
||||
{file = "rerun_sdk-0.15.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:5c067ba1c3304a0bb74bd33df8f7145ce7d405c823bfc8709396bbdd672a759e"},
|
||||
{file = "rerun_sdk-0.15.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:e8a96fff6e0c184a59b433430f5f87c96895e4b69dc0e43abb56a0e0737edc35"},
|
||||
{file = "rerun_sdk-0.15.1-cp38-abi3-win_amd64.whl", hash = "sha256:377a888e0cbe06835f376cd160ab322e9935ebd1317384381856236bd4347950"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
attrs = ">=23.1.0"
|
||||
numpy = ">=1.23,<2"
|
||||
pillow = "*"
|
||||
pyarrow = ">=14.0.2"
|
||||
typing-extensions = ">=4.5"
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest (==7.1.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "robomimic"
|
||||
version = "0.2.0"
|
||||
|
@ -4331,5 +4356,5 @@ xarm = ["gym-xarm"]
|
|||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "ba2d6275ad42f34f83193e8c64ef9dca301c6632c05523a564601d322ce7a31d"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "d2066576dc4aebaf623c295fe626bf6805fd2ec26a6ba47fa5415204994aa922"
|
||||
|
|
|
@ -27,7 +27,7 @@ packages = [{include = "lerobot"}]
|
|||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
python = ">=3.10,<3.13"
|
||||
termcolor = "^2.4.0"
|
||||
omegaconf = "^2.3.0"
|
||||
wandb = "^0.16.3"
|
||||
|
@ -58,6 +58,7 @@ datasets = "^2.19.0"
|
|||
imagecodecs = { version = "^2024.1.1", optional = true }
|
||||
pyav = "^12.0.5"
|
||||
moviepy = "^1.0.3"
|
||||
rerun-sdk = "^0.15.1"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
|
|
@ -11,7 +11,6 @@ Example usage:
|
|||
`python tests/scripts/save_dataset_to_safetensors.py`
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -29,7 +28,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
|||
|
||||
repo_dir.mkdir(parents=True, exist_ok=True)
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=repo_id, root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
repo_id=repo_id,
|
||||
)
|
||||
|
||||
# save 2 first frames of first episode
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -97,9 +96,7 @@ def test_compute_stats_on_xarm():
|
|||
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
||||
because we are working with a small dataset).
|
||||
"""
|
||||
dataset = LeRobotDataset(
|
||||
"lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
)
|
||||
dataset = LeRobotDataset("lerobot/xarm_lift_medium")
|
||||
|
||||
# reduce size of dataset sample on which stats compute is tested to 10 frames
|
||||
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
|
||||
|
@ -254,7 +251,6 @@ def test_backward_compatibility(repo_id):
|
|||
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||
)
|
||||
|
||||
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||
|
|
|
@ -18,7 +18,7 @@ def _run_script(path):
|
|||
def test_example_1():
|
||||
path = "examples/1_load_lerobot_dataset.py"
|
||||
_run_script(path)
|
||||
assert Path("outputs/examples/1_load_lerobot_dataset/episode_5.mp4").exists()
|
||||
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
|
||||
|
||||
|
||||
def test_examples_3_and_2():
|
||||
|
|
|
@ -1,31 +1,18 @@
|
|||
import pytest
|
||||
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
|
||||
from .utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
],
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
def test_visualize_dataset(tmpdir, repo_id):
|
||||
# TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset
|
||||
# doesnt support multiple timesteps which requires delta_timestamps to None for images.
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
"policy=act",
|
||||
"env=aloha",
|
||||
f"dataset_repo_id={repo_id}",
|
||||
],
|
||||
rrd_path = visualize_dataset(
|
||||
repo_id,
|
||||
episode_index=0,
|
||||
batch_size=32,
|
||||
save=True,
|
||||
output_dir=tmpdir,
|
||||
)
|
||||
video_paths = visualize_dataset(cfg, out_dir=tmpdir)
|
||||
|
||||
assert len(video_paths) > 0
|
||||
|
||||
for video_path in video_paths:
|
||||
assert video_path.exists()
|
||||
assert rrd_path.exists()
|
||||
|
|
Loading…
Reference in New Issue