address comments

Cadene 2024-04-22 14:48:53 +00:00
parent 110e5afb4b
commit db9ac59230
6 changed files with 57 additions and 8 deletions

View File

@@ -34,6 +34,7 @@ hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="t
 hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
 ```
 """
+# TODO(rcadene): remove this example file of using hf_dataset

 from pathlib import Path

@@ -46,6 +47,7 @@ from datasets import load_dataset
 hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10

 # display name of dataset and its features
+# TODO(rcadene): update to make the print pretty
 print(f"{hf_dataset=}")
 print(f"{hf_dataset.features=}")

View File

@@ -47,6 +47,7 @@ from lerobot.common.datasets.pusht import PushtDataset
 dataset = PushtDataset()

 # All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
+# TODO(rcadene): update to make the print pretty
 print(f"{dataset=}")
 print(f"{dataset.hf_dataset=}")

@@ -58,14 +59,16 @@ print(f"frames per second used during data collection: {dataset.fps=}")
 print(f"keys to access images from cameras: {dataset.image_keys=}")

 # While the LeRobot dataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
+# TODO(rcadene): remove this example of accessing hf_dataset
 dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)

 # LeRobot datasets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
 frames = [sample["observation.image"] for sample in dataset]
-# but frames are now float32 range [0,1] channel first to follow pytorch convention,
-# to view them, we convert to uint8 range [0,255] channel last
+# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
+# to view them, we convert to uint8 range [0,255]
 frames = [(frame * 255).type(torch.uint8) for frame in frames]
+# and to channel last (h,w,c)
 frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
 # and finally save them to a mp4 video
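
The conversion the new comments spell out, as a self-contained sketch: the dummy frames and the `imageio` call are assumptions (the real frames come from the dataset, and writing mp4 via `imageio.mimsave` requires the imageio-ffmpeg plugin).

```python
import imageio
import torch

# Stand-in for dataset frames: float32, channel-first (c,h,w), values in [0, 1].
frames = [torch.rand(3, 96, 96) for _ in range(10)]

# Scale to [0, 255] and cast to uint8 for viewing/encoding.
frames = [(frame * 255).type(torch.uint8) for frame in frames]

# Reorder to channel-last (h,w,c) numpy arrays, as most video writers expect.
frames = [frame.permute(1, 2, 0).numpy() for frame in frames]

# Hypothetical output path; fps should match the dataset's.
imageio.mimsave("episode_5.mp4", frames, fps=10)
```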

View File

@@ -12,8 +12,6 @@ from PIL import Image as PILImage
 from safetensors.torch import load_file
 from torchvision import transforms
-from lerobot.common.utils.utils import set_global_seed

 def flatten_dict(d, parent_key="", sep="/"):
     """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.

@@ -255,13 +253,15 @@ def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
         min[key] = torch.tensor(float("inf")).float()

 def create_seeded_dataloader(hf_dataset, batch_size, seed):
-    set_global_seed(seed)
+    generator = torch.Generator()
+    generator.manual_seed(seed)
     dataloader = torch.utils.data.DataLoader(
         hf_dataset,
         num_workers=4,
         batch_size=batch_size,
         shuffle=True,
         drop_last=False,
+        generator=generator,
     )
     return dataloader
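
The switch from `set_global_seed` to a local `torch.Generator` keeps the shuffle reproducible without mutating global RNG state. A minimal sketch of the same pattern, with a toy dataset standing in for the hf_dataset:

```python
import torch

# Toy dataset; only the seeding pattern matters here.
dataset = torch.utils.data.TensorDataset(torch.arange(10))

def make_loader(seed):
    # A per-dataloader generator reproduces the shuffle order without
    # touching torch's global RNG (unlike the removed set_global_seed).
    generator = torch.Generator()
    generator.manual_seed(seed)
    return torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, generator=generator
    )

# Same seed -> identical batch order across runs.
runs = [[batch[0].tolist() for batch in make_loader(42)] for _ in range(2)]
assert runs[0] == runs[1]
```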

View File

@@ -39,4 +39,14 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
             for _ in range(num_parallel_envs)
         ]
     )
+
+    def preprocess_observation(observation):
+        return observation / 255.0
+
+    def preprocess_reward(reward):
+        return reward / 255.0
+
+    env = gym.wrappers.TransformObservation(env, preprocess_observation)
+    env = gym.wrappers.TransformReward(env, preprocess_reward)
+
     return env
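
A standalone sketch of the wrapper pattern added above, under gymnasium 0.29-style signatures (later versions also expect an `observation_space` argument for `TransformObservation`). The env id is an assumption so the sketch runs; the real code wraps the aloha/pusht envs, where dividing uint8 image observations by 255 makes sense.

```python
import gymnasium as gym

env = gym.make("CartPole-v1")  # stand-in env, for illustration only

# Apply a function to every observation (and, as in the diff, every reward).
env = gym.wrappers.TransformObservation(env, lambda obs: obs / 255.0)
env = gym.wrappers.TransformReward(env, lambda reward: reward / 255.0)

obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```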

View File

@@ -26,7 +26,9 @@ def cat_and_write_video(video_path, frames, fps):
     # Expects images in [0, 1].
     frame = frames[0]
-    _, c, h, w = frame.shape
+    if frame.ndim == 4:
+        raise NotImplementedError("We currently dont support multiple timestamps.")
+    c, h, w = frame.shape
     assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"

     # sanity check that images are float32 in range [0,1]
@@ -55,9 +57,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
     )

     logging.info("Start rendering episodes from offline buffer")
-    video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
+    video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
     for video_path in video_paths:
         logging.info(video_path)
+    return video_paths

-def render_dataset(dataset, out_dir, max_num_episodes, fps):
+def render_dataset(dataset, out_dir, max_num_episodes):

@@ -88,7 +91,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
         # add current frame to list of frames to render
         frames[im_key].append(item[im_key])

-        end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
+        end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1

     out_dir.mkdir(parents=True, exist_ok=True)
     for im_key in dataset.image_keys:
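
The new `end_of_episode` check reads the episode's end frame from `dataset.episode_data_index["to"]` instead of a per-item field. A toy sketch of that indexing, assuming `episode_data_index` maps episode ids to `from`/`to` frame bounds with an exclusive `to` (the values below are made up):

```python
import torch

# Hypothetical bounds: episode 0 covers frames [0, 3), episode 1 covers [3, 7).
episode_data_index = {"from": torch.tensor([0, 3]), "to": torch.tensor([3, 7])}

def is_end_of_episode(frame_index, ep_id):
    # The last frame of an episode is one before its exclusive "to" bound.
    return frame_index == episode_data_index["to"][ep_id].item() - 1

assert is_end_of_episode(2, 0) and is_end_of_episode(6, 1)
assert not is_end_of_episode(3, 1)
```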

View File

@@ -0,0 +1,31 @@
+import pytest
+
+from lerobot.common.utils.utils import init_hydra_config
+from lerobot.scripts.visualize_dataset import visualize_dataset
+
+from .utils import DEFAULT_CONFIG_PATH
+
+
+@pytest.mark.parametrize(
+    "dataset_id",
+    [
+        "aloha_sim_insertion_human",
+    ],
+)
+def test_visualize_dataset(tmpdir, dataset_id):
+    # TODO(rcadene): this test might fail with other datasets/policies/envs, since visualize_dataset
+    # doesn't support multiple timesteps, which requires delta_timestamps to be None for images.
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            "policy=act",
+            "env=aloha",
+            f"dataset_id={dataset_id}",
+        ],
+    )
+
+    video_paths = visualize_dataset(cfg, out_dir=tmpdir)
+
+    assert len(video_paths) > 0
+    for video_path in video_paths:
+        assert video_path.exists()
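
Assuming the new file lives under `tests/` (the diff view omits file paths), it can be run in isolation with `pytest tests/test_visualize_dataset.py`; the `@pytest.mark.parametrize` decorator makes it straightforward to cover more dataset ids later by extending the list.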