address comments
This commit is contained in:
parent
110e5afb4b
commit
db9ac59230
|
@ -34,6 +34,7 @@ hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="t
|
|||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
|
||||
```
|
||||
"""
|
||||
# TODO(rcadene): remove this example file of using hf_dataset
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -46,6 +47,7 @@ from datasets import load_dataset
|
|||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
||||
|
||||
# display name of dataset and its features
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{hf_dataset=}")
|
||||
print(f"{hf_dataset.features=}")
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ from lerobot.common.datasets.pusht import PushtDataset
|
|||
dataset = PushtDataset()
|
||||
|
||||
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{dataset=}")
|
||||
print(f"{dataset.hf_dataset=}")
|
||||
|
||||
|
@ -58,14 +59,16 @@ print(f"frames per second used during data collection: {dataset.fps=}")
|
|||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
|
||||
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
# TODO(rcadene): remove this example of accessing hf_dataset
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
|
||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [sample["observation.image"] for sample in dataset]
|
||||
|
||||
# but frames are now float32 range [0,1] channel first to follow pytorch convention,
|
||||
# to view them, we convert to uint8 range [0,255] channel last
|
||||
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
|
||||
# to view them, we convert to uint8 range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c)
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
|
||||
# and finally save them to a mp4 video
|
||||
|
|
|
@ -12,8 +12,6 @@ from PIL import Image as PILImage
|
|||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.utils.utils import set_global_seed
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
|
@ -255,13 +253,15 @@ def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
|
|||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
def create_seeded_dataloader(hf_dataset, batch_size, seed):
|
||||
set_global_seed(seed)
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
hf_dataset,
|
||||
num_workers=4,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
generator=generator,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
|
|
|
@ -39,4 +39,14 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
|||
for _ in range(num_parallel_envs)
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_observation(observation):
|
||||
return observation / 255.0
|
||||
|
||||
def preprocess_reward(reward):
|
||||
return reward / 255.0
|
||||
|
||||
env = gym.wrappers.TransformObservation(env, preprocess_observation)
|
||||
env = gym.wrappers.TransformReward(env, preprocess_reward)
|
||||
|
||||
return env
|
||||
|
|
|
@ -26,7 +26,9 @@ def cat_and_write_video(video_path, frames, fps):
|
|||
|
||||
# Expects images in [0, 1].
|
||||
frame = frames[0]
|
||||
_, c, h, w = frame.shape
|
||||
if frame.ndim == 4:
|
||||
raise NotImplementedError("We currently dont support multiple timestamps.")
|
||||
c, h, w = frame.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
|
@ -55,9 +57,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
|||
)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
|
||||
for video_path in video_paths:
|
||||
logging.info(video_path)
|
||||
return video_paths
|
||||
|
||||
|
||||
def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
|
@ -88,7 +91,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
|||
# add current frame to list of frames to render
|
||||
frames[im_key].append(item[im_key])
|
||||
|
||||
end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
|
||||
end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for im_key in dataset.image_keys:
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import pytest
|
||||
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
|
||||
from .utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dataset_id",
|
||||
[
|
||||
"aloha_sim_insertion_human",
|
||||
],
|
||||
)
|
||||
def test_visualize_dataset(tmpdir, dataset_id):
|
||||
# TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset
|
||||
# doesnt support multiple timesteps which requires delta_timestamps to None for images.
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
"policy=act",
|
||||
"env=aloha",
|
||||
f"dataset_id={dataset_id}",
|
||||
],
|
||||
)
|
||||
video_paths = visualize_dataset(cfg, out_dir=tmpdir)
|
||||
|
||||
assert len(video_paths) > 0
|
||||
|
||||
for video_path in video_paths:
|
||||
assert video_path.exists()
|
Loading…
Reference in New Issue