fix test, example 1

This commit is contained in:
Cadene 2024-05-04 11:10:16 +00:00
parent 7b3cf7e08c
commit cb73479231
1 changed files with 13 additions and 11 deletions

View File

@ -43,25 +43,27 @@ print(f"average number of frames per episode: {dataset.num_samples / dataset.num
print(f"frames per second used during data collection: {dataset.fps=}") print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}") print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobotDataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. # access frame indices associated to episode number 5
# It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5. episode_index = 5
# TODO(rcadene): remove this example of accessing hf_dataset from_idx = dataset.episode_data_index["from"][episode_index].item()
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5) to_idx = dataset.episode_data_index["to"][episode_index].item()
# LeRobot datasets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames. # LeRobot datasets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter,
frames = [sample["observation.image"] for sample in dataset] # like iterating through the dataset. Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention, # Video frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention.
# to view them, we convert to uint8 range [0,255] # To view them, we convert to uint8 range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames] frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c) # and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames] frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# and finally save them to a mp4 video # Finally, we save the frames to a mp4 video for visualization.
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True) Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps) imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
# For many machine learning applications we need to load histories of past observations, or trajectories of future actions. Our datasets can load previous and future frames for each key/modality, # For many machine learning applications we need to load histories of past observations, or trajectories of future actions.
# Our datasets can load previous and future frames for each key/modality,
# using timestamp differences with the current loaded frame. For instance: # using timestamp differences with the current loaded frame. For instance:
delta_timestamps = { delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame