diff --git a/README.md b/README.md
index d76969bc..bf8d463a 100644
--- a/README.md
+++ b/README.md
@@ -127,13 +127,21 @@ wandb login
 
 Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.
 
-You can also locally visualize episodes from a dataset by executing our script from the command line:
+You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
 ```bash
 python lerobot/scripts/visualize_dataset.py \
     --repo-id lerobot/pusht \
     --episode-index 0
 ```
+or from a dataset in a local folder by setting the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`):
+```bash
+DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0
+```
+
+
 It will open `rerun.io` and display the camera streams, robot states and actions, like this:
 
 https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
 
@@ -141,6 +149,51 @@ https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-f
 Our script can also visualize datasets stored on a distant server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.
 
+### The `LeRobotDataset` format
+
+A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or from a local folder, simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`, and can be indexed into like any Hugging Face or PyTorch dataset. For instance, `dataset[0]` will retrieve a single temporal frame from the dataset, containing observation(s) and an action as PyTorch tensors ready to be fed to a model.
+
+A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
+
+Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data, which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that covers most types of features and specificities present in reinforcement learning and robotics, in simulation and in the real world, with a focus on cameras and robot states, but easily extended to other types of sensory inputs as long as they can be represented by a tensor.
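+
+For instance, here is a minimal sketch of loading a dataset with `delta_timestamps` (using `lerobot/pusht` and assuming its camera frames are exposed under the `observation.image` key):
+```python
+from torch.utils.data import DataLoader
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+# Retrieve the indexed frame plus 3 "previous" frames at -1s, -0.5s and -0.2s.
+delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}
+dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
+
+item = dataset[0]
+print(item["observation.image"].shape)  # 4 stacked frames, one per requested timestamp
+
+# LeRobotDataset is a regular PyTorch dataset, so it plugs directly into a DataLoader.
+dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+batch = next(iter(dataloader))
+```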
+
+Here are the important details and internal structure of a typical `LeRobotDataset`, instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features change from dataset to dataset, but the main aspects do not:
+
+```
+dataset attributes:
+  ├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
+  │  ├ observation.images.cam_high (VideoFrame):
+  │  │   VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video}
+  │  ├ observation.state (list of float32): positions of the arm joints (for instance)
+  │  ... (more observations)
+  │  ├ action (list of float32): goal positions of the arm joints (for instance)
+  │  ├ episode_index (int64): index of the episode for this sample
+  │  ├ frame_index (int64): index of the frame for this sample in the episode; starts at 0 for each episode
+  │  ├ timestamp (float32): timestamp in the episode
+  │  ├ next.done (bool): indicates the end of an episode; True for the last frame in each episode
+  │  └ index (int64): general index in the whole dataset
+  ├ episode_data_index: contains 2 tensors with the start and end indices of each episode
+  │  ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,), starts at 0
+  │  └ to (1D int64 tensor): last frame index for each episode — shape (num episodes,)
+  ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
+  │  ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
+  │  ...
+  ├ info: a dictionary of metadata on the dataset
+  │  ├ fps (float): frames per second the dataset is recorded/synchronized to
+  │  └ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
+  ├ videos_dir (Path): where the mp4 videos or png images are stored/accessed
+  └ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
+```
+
+A `LeRobotDataset` is serialized using several widespread file formats for each of its parts, namely:
+- hf_dataset is stored using the Hugging Face datasets library serialization to parquet
+- videos are stored in mp4 format to save space, or as png files
+- episode_data_index is saved using the `safetensors` tensor serialization format
+- stats are saved using the `safetensors` tensor serialization format
+- info is saved using JSON
+
+A dataset can be uploaded to and downloaded from the Hugging Face hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder, as illustrated in the above section on dataset visualization.
+
 ### Evaluate a pretrained policy
 
 Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.
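
As a quick illustration of the `LeRobotDataset` attributes described in the README section above, here is a minimal sketch of inspecting them from Python (assuming the `lerobot/aloha_static_coffee` dataset is reachable on the Hugging Face hub; the exact feature keys, e.g. `observation.state`, depend on the dataset):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/aloha_static_coffee")

print(dataset.fps)          # frames per second the dataset is recorded/synchronized to
print(dataset.video)        # True if frames are encoded as mp4 videos rather than png files
print(dataset.camera_keys)  # e.g. ["observation.images.cam_high", ...]

# hf_dataset is a plain Hugging Face dataset backed by Arrow/parquet.
print(dataset.hf_dataset)

# episode_data_index holds the first and last frame index of each episode.
print(dataset.episode_data_index["from"][0], dataset.episode_data_index["to"][0])

# stats holds per-feature statistics (max, mean, min, std), useful for normalization.
print(dataset.stats["observation.state"]["mean"])
```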
diff --git a/docker/lerobot-gpu-dev/Dockerfile b/docker/lerobot-gpu-dev/Dockerfile index e5c7d454..94b4f3ac 100644 --- a/docker/lerobot-gpu-dev/Dockerfile +++ b/docker/lerobot-gpu-dev/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.4.1-base-ubuntu22.04 +FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 # Configure image ARG PYTHON_VERSION=3.10 @@ -10,9 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ git git-lfs openssh-client \ nano vim less util-linux \ htop atop nvtop \ - sed gawk grep curl wget \ + sed gawk grep curl wget zip unzip \ tcpdump sysstat screen tmux \ - libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ + libglib2.0-0 libgl1-mesa-glx libegl1-mesa \ python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \ && apt-get clean && rm -rf /var/lib/apt/lists/* diff --git a/lerobot/common/datasets/_video_benchmark/capture_camera_feed.py b/lerobot/common/datasets/_video_benchmark/capture_camera_feed.py new file mode 100644 index 00000000..3b4c356a --- /dev/null +++ b/lerobot/common/datasets/_video_benchmark/capture_camera_feed.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Capture video feed from a camera as raw images.""" + +import argparse +import datetime as dt +from pathlib import Path + +import cv2 + + +def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int): + now = dt.datetime.now() + capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}" + if not capture_dir.exists(): + capture_dir.mkdir(parents=True, exist_ok=True) + + # Opens the default webcam + cap = cv2.VideoCapture(0) + if not cap.isOpened(): + print("Error: Could not open video stream.") + return + + cap.set(cv2.CAP_PROP_FPS, fps) + cap.set(cv2.CAP_PROP_FRAME_WIDTH, width) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height) + + frame_index = 0 + while True: + ret, frame = cap.read() + + if not ret: + print("Error: Could not read frame.") + break + + cv2.imshow("Video Stream", frame) + cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame) + frame_index += 1 + + # Break the loop on 'q' key press + if cv2.waitKey(1) & 0xFF == ord("q"): + break + + # Release the capture and destroy all windows + cap.release() + cv2.destroyAllWindows() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("outputs/cam_capture/"), + help="Directory where the capture images are written. 
A subfolder named with the current date & time will be created inside it for each capture.", + ) + parser.add_argument( + "--fps", + type=int, + default=30, + help="Frames Per Second of the capture.", + ) + parser.add_argument( + "--width", + type=int, + default=1280, + help="Width of the captured images.", + ) + parser.add_argument( + "--height", + type=int, + default=720, + help="Height of the captured images.", + ) + args = parser.parse_args() + display_and_save_video_stream(**vars(args)) diff --git a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py index 8be251dc..d92b5eef 100644 --- a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py +++ b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py @@ -13,6 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Assess the performance of video decoding in various configurations. + +This script will run different video decoding benchmarks where one parameter varies at a time. +These parameters and theirs values are specified in the BENCHMARKS dict. + +All of these benchmarks are evaluated within different timestamps modes corresponding to different frame-loading scenarios: + - `1_frame`: 1 single frame is loaded. + - `2_frames`: 2 consecutive frames are loaded. + - `2_frames_4_space`: 2 frames separated by 4 frames are loaded. + - `6_frames`: 6 consecutive frames are loaded. + +These values are more or less arbitrary and based on possible future usage. + +These benchmarks are run on the first episode of each dataset specified in DATASET_REPO_IDS. +Note: These datasets need to be image datasets, not video datasets. +""" + import json import random import shutil @@ -21,15 +38,38 @@ import time from pathlib import Path import einops -import numpy +import numpy as np import PIL import torch +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.video_utils import ( decode_video_frames_torchvision, ) +OUTPUT_DIR = Path("tmp/run_video_benchmark") +DRY_RUN = False + +DATASET_REPO_IDS = [ + "lerobot/pusht_image", + "aliberts/aloha_mobile_shrimp_image", + "aliberts/paris_street", + "aliberts/kitchen", +] +TIMESTAMPS_MODES = [ + "1_frame", + "2_frames", + "2_frames_4_space", + "6_frames", +] +BENCHMARKS = { + # "pix_fmt": ["yuv420p", "yuv444p"], + # "g": [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None], + # "crf": [0, 5, 10, 15, 20, None, 25, 30, 40, 50], + "backend": ["pyav", "video_reader"], +} + def get_directory_size(directory): total_size = 0 @@ -56,6 +96,10 @@ def run_video_benchmark( # TODO(rcadene): rewrite with hardcoding of original images and episodes dataset = LeRobotDataset(repo_id) + if dataset.video: + raise ValueError( + f"Use only image dataset for running this benchmark. 
Video dataset provided: {repo_id}" + ) # Get fps fps = dataset.fps @@ -68,10 +112,11 @@ def run_video_benchmark( if not imgs_dir.exists(): imgs_dir.mkdir(parents=True, exist_ok=True) hf_dataset = dataset.hf_dataset.with_format(None) - imgs_dataset = hf_dataset.select_columns("observation.image") + img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] + imgs_dataset = hf_dataset.select_columns(img_keys[0]) for i, item in enumerate(imgs_dataset): - img = item["observation.image"] + img = item[img_keys[0]] img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100) if i >= ep_num_images - 1: @@ -107,7 +152,7 @@ def run_video_benchmark( decoder = cfg["decoder"] decoder_kwgs = cfg["decoder_kwgs"] - device = cfg["device"] + backend = cfg["backend"] if decoder == "torchvision": decode_frames_fn = decode_video_frames_torchvision @@ -116,12 +161,12 @@ def run_video_benchmark( # Estimate average loading time - def load_original_frames(imgs_dir, timestamps): + def load_original_frames(imgs_dir, timestamps) -> torch.Tensor: frames = [] for ts in timestamps: idx = int(ts * fps) frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png") - frame = torch.from_numpy(numpy.array(frame)) + frame = torch.from_numpy(np.array(frame)) frame = frame.type(torch.float32) / 255 frame = einops.rearrange(frame, "h w c -> c h w") frames.append(frame) @@ -130,6 +175,9 @@ def run_video_benchmark( list_avg_load_time = [] list_avg_load_time_from_images = [] per_pixel_l2_errors = [] + psnr_values = [] + ssim_values = [] + mse_values = [] random.seed(seed) @@ -142,7 +190,7 @@ def run_video_benchmark( elif timestamps_mode == "2_frames": timestamps = [ts - 1 / fps, ts] elif timestamps_mode == "2_frames_4_space": - timestamps = [ts - 4 / fps, ts] + timestamps = [ts - 5 / fps, ts] elif timestamps_mode == "6_frames": timestamps = [ts - i / fps for i in range(6)][::-1] else: @@ -152,7 +200,7 @@ def run_video_benchmark( start_time_s = time.monotonic() frames = decode_frames_fn( - video_path, timestamps=timestamps, tolerance_s=1e-4, device=device, **decoder_kwgs + video_path, timestamps=timestamps, tolerance_s=1e-4, backend=backend, **decoder_kwgs ) avg_load_time = (time.monotonic() - start_time_s) / num_frames list_avg_load_time.append(avg_load_time) @@ -162,11 +210,19 @@ def run_video_benchmark( avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames list_avg_load_time_from_images.append(avg_load_time_from_images) - # Estimate average L2 error between original frames and decoded frames + # Estimate reconstruction error between original frames and decoded frames with various metrics for i, ts in enumerate(timestamps): # are_close = torch.allclose(frames[i], original_frames[i], atol=0.02) num_pixels = original_frames[i].numel() per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels + per_pixel_l2_errors.append(per_pixel_l2_error) + + frame_np, original_frame_np = frames[i].numpy(), original_frames[i].numpy() + psnr_values.append(peak_signal_noise_ratio(original_frame_np, frame_np, data_range=1.0)) + ssim_values.append( + structural_similarity(original_frame_np, frame_np, data_range=1.0, channel_axis=0) + ) + mse_values.append(mean_squared_error(original_frame_np, frame_np)) # save decoded frames if t == 0: @@ -179,15 +235,18 @@ def run_video_benchmark( original_frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png") original_frame.save(output_dir / f"original_frame_{i:06d}.png") - per_pixel_l2_errors.append(per_pixel_l2_error) - - avg_load_time = 
float(numpy.array(list_avg_load_time).mean()) - avg_load_time_from_images = float(numpy.array(list_avg_load_time_from_images).mean()) - avg_per_pixel_l2_error = float(numpy.array(per_pixel_l2_errors).mean()) + image_size = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:]) + avg_load_time = float(np.array(list_avg_load_time).mean()) + avg_load_time_from_images = float(np.array(list_avg_load_time_from_images).mean()) + avg_per_pixel_l2_error = float(np.array(per_pixel_l2_errors).mean()) + avg_psnr = float(np.mean(psnr_values)) + avg_ssim = float(np.mean(ssim_values)) + avg_mse = float(np.mean(mse_values)) # Save benchmark info info = { + "image_size": image_size, "sum_original_frames_size_bytes": sum_original_frames_size_bytes, "video_size_bytes": video_size_bytes, "avg_load_time_from_images": avg_load_time_from_images, @@ -195,6 +254,9 @@ def run_video_benchmark( "compression_factor": sum_original_frames_size_bytes / video_size_bytes, "load_time_factor": avg_load_time_from_images / avg_load_time, "avg_per_pixel_l2_error": avg_per_pixel_l2_error, + "avg_psnr": avg_psnr, + "avg_ssim": avg_ssim, + "avg_mse": avg_mse, } with open(output_dir / "info.json", "w") as f: @@ -234,138 +296,113 @@ def load_info(out_dir): return info -def main(): - out_dir = Path("tmp/run_video_benchmark") - dry_run = False - repo_ids = ["lerobot/pusht", "lerobot/umi_cup_in_the_wild"] - timestamps_modes = [ - "1_frame", - "2_frames", - "2_frames_4_space", - "6_frames", +def one_variable_study( + var_name: str, var_values: list, repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: bool +): + print(f"**`{var_name}`**") + headers = [ + "repo_id", + "image_size", + var_name, + "compression_factor", + "load_time_factor", + "avg_per_pixel_l2_error", + "avg_psnr", + "avg_ssim", + "avg_mse", ] - for timestamps_mode in timestamps_modes: - bench_dir = out_dir / timestamps_mode + rows = [] + base_cfg = { + "repo_id": None, + # video encoding + "g": 2, + "crf": None, + "pix_fmt": "yuv444p", + # video decoding + "backend": "pyav", + "decoder": "torchvision", + "decoder_kwgs": {}, + } + for repo_id in repo_ids: + for val in var_values: + cfg = base_cfg.copy() + cfg["repo_id"] = repo_id + cfg[var_name] = val + if not dry_run: + run_video_benchmark( + bench_dir / repo_id / f"torchvision_{var_name}_{val}", cfg, timestamps_mode + ) + info = load_info(bench_dir / repo_id / f"torchvision_{var_name}_{val}") + width, height = info["image_size"][0], info["image_size"][1] + rows.append( + [ + repo_id, + f"{width} x {height}", + val, + info["compression_factor"], + info["load_time_factor"], + info["avg_per_pixel_l2_error"], + info["avg_psnr"], + info["avg_ssim"], + info["avg_mse"], + ] + ) + display_markdown_table(headers, rows) + + +def best_study(repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: bool): + """Change the config once you deciced what's best based on one-variable-studies""" + print("**best**") + headers = [ + "repo_id", + "image_size", + "compression_factor", + "load_time_factor", + "avg_per_pixel_l2_error", + "avg_psnr", + "avg_ssim", + "avg_mse", + ] + rows = [] + for repo_id in repo_ids: + cfg = { + "repo_id": repo_id, + # video encoding + "g": 2, + "crf": None, + "pix_fmt": "yuv444p", + # video decoding + "backend": "video_reader", + "decoder": "torchvision", + "decoder_kwgs": {}, + } + if not dry_run: + run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode) + info = load_info(bench_dir / repo_id / "torchvision_best") + width, height = info["image_size"][0], 
info["image_size"][1] + rows.append( + [ + repo_id, + f"{width} x {height}", + info["compression_factor"], + info["load_time_factor"], + info["avg_per_pixel_l2_error"], + ] + ) + display_markdown_table(headers, rows) + + +def main(): + for timestamps_mode in TIMESTAMPS_MODES: + bench_dir = OUTPUT_DIR / timestamps_mode print(f"### `{timestamps_mode}`") print() - print("**`pix_fmt`**") - headers = ["repo_id", "pix_fmt", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"] - rows = [] - for repo_id in repo_ids: - for pix_fmt in ["yuv420p", "yuv444p"]: - cfg = { - "repo_id": repo_id, - # video encoding - "g": 2, - "crf": None, - "pix_fmt": pix_fmt, - # video decoding - "device": "cpu", - "decoder": "torchvision", - "decoder_kwgs": {}, - } - if not dry_run: - run_video_benchmark(bench_dir / repo_id / f"torchvision_{pix_fmt}", cfg, timestamps_mode) - info = load_info(bench_dir / repo_id / f"torchvision_{pix_fmt}") - rows.append( - [ - repo_id, - pix_fmt, - info["compression_factor"], - info["load_time_factor"], - info["avg_per_pixel_l2_error"], - ] - ) - display_markdown_table(headers, rows) + for name, values in BENCHMARKS.items(): + one_variable_study(name, values, DATASET_REPO_IDS, bench_dir, timestamps_mode, DRY_RUN) - print("**`g`**") - headers = ["repo_id", "g", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"] - rows = [] - for repo_id in repo_ids: - for g in [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None]: - cfg = { - "repo_id": repo_id, - # video encoding - "g": g, - "pix_fmt": "yuv444p", - # video decoding - "device": "cpu", - "decoder": "torchvision", - "decoder_kwgs": {}, - } - if not dry_run: - run_video_benchmark(bench_dir / repo_id / f"torchvision_g_{g}", cfg, timestamps_mode) - info = load_info(bench_dir / repo_id / f"torchvision_g_{g}") - rows.append( - [ - repo_id, - g, - info["compression_factor"], - info["load_time_factor"], - info["avg_per_pixel_l2_error"], - ] - ) - display_markdown_table(headers, rows) - - print("**`crf`**") - headers = ["repo_id", "crf", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"] - rows = [] - for repo_id in repo_ids: - for crf in [0, 5, 10, 15, 20, None, 25, 30, 40, 50]: - cfg = { - "repo_id": repo_id, - # video encoding - "g": 2, - "crf": crf, - "pix_fmt": "yuv444p", - # video decoding - "device": "cpu", - "decoder": "torchvision", - "decoder_kwgs": {}, - } - if not dry_run: - run_video_benchmark(bench_dir / repo_id / f"torchvision_crf_{crf}", cfg, timestamps_mode) - info = load_info(bench_dir / repo_id / f"torchvision_crf_{crf}") - rows.append( - [ - repo_id, - crf, - info["compression_factor"], - info["load_time_factor"], - info["avg_per_pixel_l2_error"], - ] - ) - display_markdown_table(headers, rows) - - print("**best**") - headers = ["repo_id", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"] - rows = [] - for repo_id in repo_ids: - cfg = { - "repo_id": repo_id, - # video encoding - "g": 2, - "crf": None, - "pix_fmt": "yuv444p", - # video decoding - "device": "cpu", - "decoder": "torchvision", - "decoder_kwgs": {}, - } - if not dry_run: - run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode) - info = load_info(bench_dir / repo_id / "torchvision_best") - rows.append( - [ - repo_id, - info["compression_factor"], - info["load_time_factor"], - info["avg_per_pixel_l2_error"], - ] - ) - display_markdown_table(headers, rows) + # best_study(DATASET_REPO_IDS, bench_dir, timestamps_mode, DRY_RUN) if __name__ == "__main__": diff --git 
a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 754bc91b..96a353fb 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -96,6 +96,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData split=split, delta_timestamps=cfg.training.get("delta_timestamps"), image_transforms=image_transforms, + video_backend=cfg.video_backend, ) else: dataset = MultiLeRobotDataset( @@ -103,6 +104,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData split=split, delta_timestamps=cfg.training.get("delta_timestamps"), image_transforms=image_transforms, + video_backend=cfg.video_backend, ) if cfg.get("override_dataset_stats"): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d680b987..61c35aa4 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -48,6 +48,7 @@ class LeRobotDataset(torch.utils.data.Dataset): split: str = "train", image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, + video_backend: str | None = None, ): super().__init__() self.repo_id = repo_id @@ -69,6 +70,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.info = load_info(repo_id, version, root) if self.video: self.videos_dir = load_videos(repo_id, version, root) + self.video_backend = video_backend if video_backend is not None else "pyav" @property def fps(self) -> int: @@ -149,6 +151,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.video_frame_keys, self.videos_dir, self.tolerance_s, + self.video_backend, ) if self.image_transforms is not None: @@ -188,6 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset): stats=None, info=None, videos_dir=None, + video_backend=None, ) -> "LeRobotDataset": """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem. @@ -210,6 +214,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.stats = stats obj.info = info if info is not None else {} obj.videos_dir = videos_dir + obj.video_backend = video_backend if video_backend is not None else "pyav" return obj @@ -228,6 +233,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): split: str = "train", image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, + video_backend: str | None = None, ): super().__init__() self.repo_ids = repo_ids @@ -241,6 +247,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): split=split, delta_timestamps=delta_timestamps, image_transforms=image_transforms, + video_backend=video_backend, ) for repo_id in repo_ids ] diff --git a/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py new file mode 100644 index 00000000..4972e6b4 --- /dev/null +++ b/lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains utilities to process raw data format of png images files recorded with capture_camera_feed.py +""" + +from pathlib import Path + +import torch +from datasets import Dataset, Features, Image, Value +from PIL import Image as PILImage + +from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes +from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch +from lerobot.common.datasets.video_utils import VideoFrame + + +def check_format(raw_dir: Path) -> bool: + image_paths = list(raw_dir.glob("frame_*.png")) + if len(image_paths) == 0: + raise ValueError + + +def load_from_raw(raw_dir: Path, fps: int, episodes: list[int] | None = None): + if episodes is not None: + # TODO(aliberts): add support for multi-episodes. + raise NotImplementedError() + + ep_dict = {} + ep_idx = 0 + + image_paths = sorted(raw_dir.glob("frame_*.png")) + num_frames = len(image_paths) + + ep_dict["observation.image"] = [PILImage.open(x) for x in image_paths] + ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames) + ep_dict["frame_index"] = torch.arange(0, num_frames, 1) + ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps + + ep_dicts = [ep_dict] + data_dict = concatenate_episodes(ep_dicts) + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict + + +def to_hf_dataset(data_dict, video) -> Dataset: + features = {} + if video: + features["observation.image"] = VideoFrame() + else: + features["observation.image"] = Image() + + features["episode_index"] = Value(dtype="int64", id=None) + features["frame_index"] = Value(dtype="int64", id=None) + features["timestamp"] = Value(dtype="float32", id=None) + features["index"] = Value(dtype="int64", id=None) + + hf_dataset = Dataset.from_dict(data_dict, features=Features(features)) + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): + if video or episodes is not None: + # TODO(aliberts): support this + raise NotImplementedError + + # sanity check + check_format(raw_dir) + + if fps is None: + fps = 30 + + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) + hf_dataset = to_hf_dataset(data_dict, video) + episode_data_index = calculate_episode_data_index(hf_dataset) + info = { + "fps": fps, + "video": video, + } + return hf_dataset, episode_data_index, info diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index edfca918..fdc4fbe9 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -27,7 +27,11 @@ from datasets.features.features import register_feature def load_from_videos( - item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float + item: dict[str, torch.Tensor], + video_frame_keys: list[str], + videos_dir: Path, + tolerance_s: float, + backend: str = "pyav", ): """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. 
@@ -46,14 +50,14 @@ def load_from_videos( raise NotImplementedError("All video paths are expected to be the same for now.") video_path = data_dir / paths[0] - frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s) + frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) item[key] = frames else: # load one frame timestamps = [item[key]["timestamp"]] video_path = data_dir / item[key]["path"] - frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s) + frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) item[key] = frames[0] return item @@ -63,11 +67,23 @@ def decode_video_frames_torchvision( video_path: str, timestamps: list[float], tolerance_s: float, - device: str = "cpu", + backend: str = "pyav", log_loaded_timestamps: bool = False, ): """Loads frames associated to the requested timestamps of a video + The backend can be either "pyav" (default) or "video_reader". + "video_reader" requires installing torchvision from source, see: + https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst + (note that you need to compile against ffmpeg<4.3) + + While both use cpu, "video_reader" is faster than "pyav" but requires additional setup. + See our benchmark results for more info on performance: + https://github.com/huggingface/lerobot/pull/220 + + See torchvision doc for more info on these two backends: + https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend + Note: Video benefits from inter-frame compression. Instead of storing every frame individually, the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame, @@ -78,21 +94,9 @@ def decode_video_frames_torchvision( # set backend keyframes_only = False - if device == "cpu": - # explicitely use pyav - torchvision.set_video_backend("pyav") + torchvision.set_video_backend(backend) + if backend == "pyav": keyframes_only = True # pyav doesnt support accuracte seek - elif device == "cuda": - # TODO(rcadene, aliberts): implement video decoding with GPU - # torchvision.set_video_backend("cuda") - # torchvision.set_video_backend("video_reader") - # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst - # check possible bug: https://github.com/pytorch/vision/issues/7745 - raise NotImplementedError( - "Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`." - ) - else: - raise ValueError(device) # set a video stream reader # TODO(rcadene): also load audio stream at the same time @@ -120,7 +124,9 @@ def decode_video_frames_torchvision( if current_ts >= last_ts: break - reader.container.close() + if backend == "pyav": + reader.container.close() + reader = None query_ts = torch.tensor(timestamps) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index bef59bec..5f302bc7 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -314,9 +314,23 @@ class ACT(nn.Module): # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) + # Prepare key padding mask for the transformer encoder. 
We have 1 or 2 extra tokens at the start of the + # sequence depending whether we use the input states or not (cls and robot state) + # False means not a padding token. + cls_joint_is_pad = torch.full( + (batch_size, 2 if self.use_input_state else 1), + False, + device=batch["observation.state"].device, + ) + key_padding_mask = torch.cat( + [cls_joint_is_pad, batch["action_is_pad"]], axis=1 + ) # (bs, seq+1 or 2) + # Forward pass through VAE encoder to get the latent PDF parameters. cls_token_out = self.vae_encoder( - vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2) + vae_encoder_input.permute(1, 0, 2), + pos_embed=pos_embed.permute(1, 0, 2), + key_padding_mask=key_padding_mask, )[0] # select the class token, with shape (B, D) latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) mu = latent_pdf_params[:, : self.config.latent_dim] @@ -402,9 +416,11 @@ class ACTEncoder(nn.Module): self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)]) self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() - def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: + def forward( + self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None + ) -> Tensor: for layer in self.layers: - x = layer(x, pos_embed=pos_embed) + x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) x = self.norm(x) return x @@ -427,12 +443,13 @@ class ACTEncoderLayer(nn.Module): self.activation = get_activation_fn(config.feedforward_activation) self.pre_norm = config.pre_norm - def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: + def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: skip = x if self.pre_norm: x = self.norm1(x) q = k = x if pos_embed is None else x + pos_embed - x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask) + x = x[0] # note: [0] to select just the output, not the attention weights x = skip + self.dropout1(x) if self.pre_norm: skip = x diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 6101df89..df0dae7d 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -28,6 +28,7 @@ seed: ??? # "dataset_index" into the returned item. The index mapping is made according to the order in which the # datsets are provided. dataset_repo_id: lerobot/pusht +video_backend: pyav training: offline_steps: ??? @@ -38,9 +39,10 @@ training: # `online_env_seed` is used for environments for online training data rollouts. online_env_seed: ??? eval_freq: ??? - save_freq: ??? log_freq: 250 save_checkpoint: true + # Checkpoint is saved every `save_freq` training iterations and after the last training step. + save_freq: ??? num_workers: 4 batch_size: ??? 
image_transforms: diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 94cc960e..e473aa0c 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -55,7 +55,6 @@ from safetensors.torch import save_file from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw from lerobot.common.datasets.utils import flatten_dict @@ -72,6 +71,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str): from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format elif raw_format == "xarm_pkl": from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format + elif raw_format == "cam_png": + from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format else: raise ValueError( f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?" @@ -184,10 +185,6 @@ def push_dataset_to_hub( meta_data_dir = Path(cache_dir) / "meta_data" videos_dir = Path(cache_dir) / "videos" - # Download the raw dataset if available - if not raw_dir.exists(): - download_raw(raw_dir, dataset_id) - if raw_format is None: # TODO(rcadene, adilzouitine): implement auto_find_raw_format raise NotImplementedError() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 01b2ef4f..796881c4 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -53,12 +53,14 @@ def make_optimizer_and_scheduler(cfg, policy): "params": [ p for n, p in policy.named_parameters() - if not n.startswith("backbone") and p.requires_grad + if not n.startswith("model.backbone") and p.requires_grad ] }, { "params": [ - p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad + p + for n, p in policy.named_parameters() + if n.startswith("model.backbone") and p.requires_grad ], "lr": cfg.training.lr_backbone, }, @@ -349,7 +351,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logger.log_video(eval_info["video_paths"][0], step, mode="eval") logging.info("Resume training") - if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0: + if cfg.training.save_checkpoint and ( + step % cfg.training.save_freq == 0 + or step == cfg.training.offline_steps + cfg.training.online_steps + ): logging.info(f"Checkpoint policy after step {step}") # Note: Save with step as the identifier, and format it to have at least 6 digits but more if # needed (choose 6 as a minimum for consistency without being overkill). diff --git a/poetry.lock b/poetry.lock index e9d6e848..48c6c057 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -3805,31 +3805,31 @@ files = [ [[package]] name = "torch" -version = "2.3.0" +version = "2.3.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac"}, - {file = "torch-2.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c"}, - {file = "torch-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459"}, - {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, - {file = "torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788"}, - {file = "torch-2.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace"}, - {file = "torch-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877"}, - {file = "torch-2.3.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73"}, - {file = "torch-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410"}, - {file = "torch-2.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542"}, - {file = "torch-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd"}, - {file = "torch-2.3.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad"}, - {file = "torch-2.3.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061"}, - {file = "torch-2.3.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932"}, - {file = "torch-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6"}, - {file = "torch-2.3.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba"}, - {file = "torch-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9"}, - {file = "torch-2.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80"}, - {file = "torch-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea"}, - {file = "torch-2.3.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380"}, + {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"}, + {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"}, + {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"}, + {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", 
hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"}, + {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"}, + {file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"}, + {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"}, + {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"}, + {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"}, + {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"}, + {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"}, + {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"}, + {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"}, + {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"}, + {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"}, + {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"}, + {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"}, + {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"}, + {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"}, + {file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"}, ] [package.dependencies] @@ -3850,7 +3850,7 @@ nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \" nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} typing-extensions = ">=4.8.0" [package.extras] @@ -3859,37 +3859,37 @@ optree = ["optree (>=0.9.1)"] [[package]] name = "torchvision" -version = "0.18.0" +version = "0.18.1" description = "image and video datasets and models for torch deep learning" optional = false python-versions = ">=3.8" files = [ - {file = "torchvision-0.18.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dd61628a3d189c6852a12dc5ed4cd2eece66d2d67f35a866cb16f1dcb06c8c62"}, - {file = "torchvision-0.18.0-cp310-cp310-manylinux1_x86_64.whl", hash = 
"sha256:493c45f9937dad37aa1b64b14da17c7a589c72b91adc4837d431009cfe29bd53"}, - {file = "torchvision-0.18.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5337f6acfa1fe959d5cb340d01a00614d6b31ce7a4824ccb95435a85c5273b95"}, - {file = "torchvision-0.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:bd8e6f3b5beb49965f15c461302488edfa3d8c2d01d3bb79b150d6fb62711e3a"}, - {file = "torchvision-0.18.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6896a52168befe1105fb3c9335287390ed227e71d1e4ec4d68b62e8a3099fc09"}, - {file = "torchvision-0.18.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:3d7955398d4ceaad77c487c2c44f6f7813112402c9bab8cd906d346005891048"}, - {file = "torchvision-0.18.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e5a24d620cea14a4bb89f24aa2b506230c0a16a3ada57fc53ad80cfd256a2128"}, - {file = "torchvision-0.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:6ad70ddfa879bda5ed886b2518fe562640e0059787cbd65cb2bffa7674541410"}, - {file = "torchvision-0.18.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:eb9d83c0e1dbb54ecb0fb04c87f786333e3a6fb8b9c400aca7c31081f9aa5707"}, - {file = "torchvision-0.18.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b657d052d146f24cb3b2a78219bfc82ae70a9706671c50f632528907d10cccec"}, - {file = "torchvision-0.18.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a964afbc7ddf50a46b941477f6c35729b416deedd139756befd488245e2e226d"}, - {file = "torchvision-0.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:7c770f0f748e0b17f57c0297508d7254f686cdf03fc2e2949f422b20574f4c0f"}, - {file = "torchvision-0.18.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2115a1906c015f5da9ceedc40a983313b0fd6e2c8a17108a92991706f51f6987"}, - {file = "torchvision-0.18.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:6323f7e5423ff2594d5891863b919deb9d0de95f01c36bf26fbd879036b6ed08"}, - {file = "torchvision-0.18.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:925d0a82cccf6f986c18b29b4392a942db65cbdb73c13a129c8493822eb9e36f"}, - {file = "torchvision-0.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:95b42d0dc599b47a01530c7439a5751e67e45b85e3a67113989cf7c7c70f2039"}, - {file = "torchvision-0.18.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:75e22ecf44a13b8f95b8ad421c0261282d859c61816badaca1959e073ccdd691"}, - {file = "torchvision-0.18.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:4c334b3e719ba0a9ba6e15d4aff1178f5e6d029174f346163fed525f0ccfffd3"}, - {file = "torchvision-0.18.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:36efd87001c6bee2383e043e46a025affb03179747c8f4777b9918527ffce756"}, - {file = "torchvision-0.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:ccc292e093771d5baacf5535ac4416306b6b5f15676341cd4d010d8542eace25"}, + {file = "torchvision-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3e694e54b0548dad99c12af6bf0c8e4f3350137d391dcd19af22a1c5f89322b3"}, + {file = "torchvision-0.18.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:0b3bda0aa5b416eeb547143b8eeaf17720bdba9cf516dc991aacb81811aa96a5"}, + {file = "torchvision-0.18.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:573ff523c739405edb085f65cb592f482d28a30e29b0be4c4ba08040b3ae785f"}, + {file = "torchvision-0.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:ef7bbbc60b38e831a75e547c66ca1784f2ac27100f9e4ddbe9614cef6cbcd942"}, + {file = "torchvision-0.18.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:80b5d794dd0fdba787adc22f1a367a5ead452327686473cb260dd94364bc56a6"}, + {file = "torchvision-0.18.1-cp311-cp311-manylinux1_x86_64.whl", hash = 
"sha256:9077cf590cdb3a5e8fdf5cdb71797f8c67713f974cf0228ecb17fcd670ab42f9"}, + {file = "torchvision-0.18.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ceb993a882f1ae7ae373ed39c28d7e3e802205b0e59a7ed84ef4028f0bba8d7f"}, + {file = "torchvision-0.18.1-cp311-cp311-win_amd64.whl", hash = "sha256:52f7436140045dc2239cdc502aa76b2bd8bd676d64244ff154d304aa69852046"}, + {file = "torchvision-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2be6f0bf7c455c89a51a1dbb6f668d36c6edc479f49ac912d745d10df5715657"}, + {file = "torchvision-0.18.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:f118d887bfde3a948a41d56587525401e5cac1b7db2eaca203324d6ed2b1caca"}, + {file = "torchvision-0.18.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:13d24d904f65e62d66a1e0c41faec630bc193867b8a4a01166769e8a8e8df8e9"}, + {file = "torchvision-0.18.1-cp312-cp312-win_amd64.whl", hash = "sha256:ed6340b69a63a625e512a66127210d412551d9c5f2ad2978130c6a45bf56cd4a"}, + {file = "torchvision-0.18.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b1c3864fa9378c88bce8ad0ef3599f4f25397897ce612e1c245c74b97092f35e"}, + {file = "torchvision-0.18.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:02085a2ffc7461f5c0edb07d6f3455ee1806561f37736b903da820067eea58c7"}, + {file = "torchvision-0.18.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9726c316a2501df8503e5a5dc46a631afd4c515a958972e5b7f7b9c87d2125c0"}, + {file = "torchvision-0.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:64a2662dbf30db9055d8b201d6e56f312a504e5ccd9d144c57c41622d3c524cb"}, + {file = "torchvision-0.18.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:975b8594c0f5288875408acbb74946eea786c5b008d129c0d045d0ead23742bc"}, + {file = "torchvision-0.18.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:da83c8bbd34d8bee48bfa1d1b40e0844bc3cba10ed825a5a8cbe3ce7b62264cd"}, + {file = "torchvision-0.18.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:54bfcd352abb396d5c9c237d200167c178bd136051b138e1e8ef46ce367c2773"}, + {file = "torchvision-0.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:5c8366a1aeee49e9ea9e64b30d199debdf06b1bd7610a76165eb5d7869c3bde5"}, ] [package.dependencies] numpy = "*" pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" -torch = "2.3.0" +torch = "2.3.1" [package.extras] scipy = ["scipy"] @@ -3916,17 +3916,17 @@ telegram = ["requests"] [[package]] name = "triton" -version = "2.3.0" +version = "2.3.1" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" files = [ - {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"}, - {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"}, - {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"}, - {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"}, - {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"}, - {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"}, + {file = 
"triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, + {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, + {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"}, + {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"}, + {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, + {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, ] [package.dependencies] @@ -4301,9 +4301,10 @@ dora = ["gym-dora"] pusht = ["gym-pusht"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] +video-benchmark = ["scikit-image"] xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771" +content-hash = "61f99befbc2250fe59cb54119c3dbd3aa3c1dfe5d3d7790c6f7c4f90fe43112e" diff --git a/pyproject.toml b/pyproject.toml index 0c305218..15ac9c73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ pyav = ">=12.0.5" moviepy = ">=1.0.3" rerun-sdk = ">=0.15.1" deepdiff = ">=7.0.1" +scikit-image = {version = "^0.23.2", optional = true} [tool.poetry.extras] @@ -70,6 +71,7 @@ aloha = ["gym-aloha"] dev = ["pre-commit", "debugpy"] test = ["pytest", "pytest-cov"] umi = ["imagecodecs"] +video_benchmark = ["scikit-image"] [tool.ruff] line-length = 110 diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors new file mode 100644 index 00000000..1529153d --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/actions.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef +size 515400 diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors new file mode 100644 index 00000000..6a359f4e --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/grad_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da +size 31688 diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors new file mode 100644 index 00000000..09901110 --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/output_dict.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03971f92b7907b6b7e6ac207f508666104cd84c26c5276f510c431db604e188b +size 68 diff --git a/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors new file mode 100644 index 
00000000..157c382c --- /dev/null +++ b/tests/data/save_policy_to_safetensors/aloha_act_1000_steps/param_stats.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a +size 33408 diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 961b7cef..5fead55a 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -89,8 +89,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides): return output_dict, grad_stats, param_stats, actions -def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides): - env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}" +def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra): + env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}" if env_policy_dir.exists(): print(f"Overwrite existing safetensors in '{env_policy_dir}':") @@ -108,15 +108,17 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override if __name__ == "__main__": env_policies = [ - ("xarm", "tdmpc", []), - ( - "pusht", - "diffusion", - ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], - ), - ("aloha", "act", ["policy.n_action_steps=10"]), - ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), - ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), + # ("xarm", "tdmpc", []), + # ( + # "pusht", + # "diffusion", + # ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], + # ), + ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"), + # ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), + # ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), ] - for env, policy, extra_overrides in env_policies: - save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides) + for env, policy, extra_overrides, file_name_extra in env_policies: + save_policy_to_safetensors( + "tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra + ) diff --git a/tests/test_policies.py b/tests/test_policies.py index c099bef0..fdc74751 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -30,6 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config +from lerobot.scripts.train import make_optimizer_and_scheduler from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -174,6 +175,33 @@ def test_policy(env_name, policy_name, extra_overrides): env.step(action) +def test_act_backbone_lr(): + """ + Test that the ACT policy can be instantiated with a different learning rate for the backbone. 
+ """ + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[ + "env=aloha", + "policy=act", + f"device={DEVICE}", + "training.lr_backbone=0.001", + "training.lr=0.01", + ], + ) + assert cfg.training.lr == 0.01 + assert cfg.training.lr_backbone == 0.001 + + dataset = make_dataset(cfg) + policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) + optimizer, _ = make_optimizer_and_scheduler(cfg, policy) + assert len(optimizer.param_groups) == 2 + assert optimizer.param_groups[0]["lr"] == cfg.training.lr + assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone + assert len(optimizer.param_groups[0]["params"]) == 133 + assert len(optimizer.param_groups[1]["params"]) == 20 + + @pytest.mark.parametrize("policy_name", available_policies) def test_policy_defaults(policy_name: str): """Check that the policy can be instantiated with defaults.""" @@ -287,24 +315,26 @@ def test_normalize(insert_temporal_dim): @pytest.mark.parametrize( - "env_name, policy_name, extra_overrides", + "env_name, policy_name, extra_overrides, file_name_extra", [ - ("xarm", "tdmpc", []), + ("xarm", "tdmpc", [], ""), ( "pusht", "diffusion", ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], + "", ), - ("aloha", "act", ["policy.n_action_steps=10"]), - ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), - ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), + ("aloha", "act", ["policy.n_action_steps=10"], ""), + ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"), + ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""), + ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""), ], ) # As artifacts have been generated on an x86_64 kernel, this test won't # pass if it's run on another platform due to floating point errors @require_x86_64_kernel @require_cpu -def test_backward_compatibility(env_name, policy_name, extra_overrides): +def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra): """ NOTE: If this test does not pass, and you have intentionally changed something in the policy: 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should @@ -316,7 +346,9 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides): 5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state. 6. Remember to stage and commit the resulting changes to `tests/data`. """ - env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}" + env_policy_dir = ( + Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}" + ) saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors") saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors") saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")