From c1232a01e2e2872e7250135d6a560f6cfef607b9 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 21 Oct 2024 00:16:52 +0200 Subject: [PATCH] Add add_frame, empty dataset creation --- lerobot/common/datasets/lerobot_dataset.py | 79 ++++++++++++++++--- lerobot/common/datasets/utils.py | 34 ++++++-- .../common/robot_devices/cameras/opencv.py | 4 + lerobot/common/robot_devices/control_utils.py | 8 +- .../robot_devices/robots/manipulator.py | 7 ++ lerobot/scripts/control_robot.py | 15 ++-- 6 files changed, 114 insertions(+), 33 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 43d8708d..61331c5a 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import logging import os from pathlib import Path @@ -26,15 +25,17 @@ from datasets import load_dataset from huggingface_hub import snapshot_download from lerobot.common.datasets.compute_stats import aggregate_stats +from lerobot.common.datasets.image_writer import ImageWriter from lerobot.common.datasets.utils import ( check_delta_timestamps, check_timestamps_sync, - create_dataset_info, + create_empty_dataset_info, get_delta_indices, get_episode_data_index, get_hub_safe_version, hf_transform_to_torch, load_metadata, + write_json, ) from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision from lerobot.common.robot_devices.robots.utils import Robot @@ -55,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset): tolerance_s: float = 1e-4, download_videos: bool = True, video_backend: str | None = None, + image_writer: ImageWriter | None = None, ): """LeRobotDataset encapsulates 3 main things: - metadata: @@ -156,6 +158,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s = tolerance_s self.download_videos = download_videos self.video_backend = video_backend if video_backend is not None else "pyav" + self.image_writer = image_writer + self.episode_buffer = {} self.delta_indices = None # Load metadata @@ -296,9 +300,14 @@ class LeRobotDataset(torch.utils.data.Dataset): @property def num_samples(self) -> int: - """Number of samples/frames.""" + """Number of samples/frames in selected episodes.""" return len(self.hf_dataset) + @property + def total_frames(self) -> int: + """Total number of frames saved in this dataset.""" + return self.info["total_frames"] + @property def num_episodes(self) -> int: """Number of episodes selected.""" @@ -423,10 +432,6 @@ class LeRobotDataset(torch.utils.data.Dataset): return item - def write_info(self) -> None: - with open(self.root / "meta/info.json", "w") as f: - json.dump(self.info, f, indent=4, ensure_ascii=False) - def __repr__(self): return ( f"{self.__class__.__name__}(\n" @@ -442,6 +447,49 @@ class LeRobotDataset(torch.utils.data.Dataset): f")" ) + def _create_episode_buffer(self) -> dict: + # TODO(aliberts): Handle resume + return { + "chunk": self.total_chunks, + "episode_index": self.total_episodes, + "size": 0, + "frame_index": [], + "timestamp": [], + "next.done": [], + **{key: [] for key in self.keys}, + } + + def add_frame(self, frame: dict) -> None: + frame_index = self.episode_buffer["size"] + self.episode_buffer["frame_index"].append(frame_index) + self.episode_buffer["timestamp"].append(frame_index / self.fps) + self.episode_buffer["next.done"].append(False) + + # Save all observed modalities except images + for key in self.keys: + self.episode_buffer[key].append(frame[key]) + + self.episode_buffer["size"] += 1 + + if self.image_writer is None: + return + + # Save images + for cam_key in self.camera_keys: + img_path = self.image_writer.get_image_file_path( + episode_index=self.episode_buffer["episode_index"], + image_key=cam_key, + frame_index=frame_index, + return_str=False, + ) + if frame_index == 0: + img_path.parent.mkdir(parents=True, exist_ok=True) + + self.image_writer.async_save_image( + image=frame[cam_key], + file_path=img_path, + ) + @classmethod def create( cls, @@ -450,24 +498,29 @@ class LeRobotDataset(torch.utils.data.Dataset): robot: Robot, root: Path | None = None, tolerance_s: float = 1e-4, + image_writer: ImageWriter | None = None, + use_videos: bool = True, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" obj = cls.__new__(cls) obj.repo_id = repo_id obj.root = root if root is not None else LEROBOT_HOME / repo_id obj._version = CODEBASE_VERSION + obj.tolerance_s = tolerance_s + obj.image_writer = image_writer - obj.root.mkdir(exist_ok=True, parents=True) - obj.info = create_dataset_info(obj._version, fps, robot) - obj.write_info() - obj.fps = fps - - if not all(cam.fps == fps for cam in robot.cameras): + if not all(cam.fps == fps for cam in robot.cameras.values()): logging.warn( f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." "In this case, frames from lower fps cameras will be repeated to fill in the blanks" ) + obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos) + write_json(obj.info, obj.root / "meta/info.json") + + # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer + obj.episode_buffer = obj._create_episode_buffer() + # obj.episodes = None # obj.image_transforms = None # obj.delta_timestamps = None diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 90bb35c1..79459882 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -75,6 +75,12 @@ def unflatten_dict(d, sep="/"): return outdict +def write_json(data: dict, fpath: Path) -> None: + fpath.parent.mkdir(exist_ok=True, parents=True) + with open(fpath, "w") as f: + json.dump(data, f, indent=4, ensure_ascii=False) + + def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to @@ -146,7 +152,16 @@ def load_metadata(local_dir: Path) -> tuple[dict | list]: return info, episode_dicts, stats, tasks -def create_dataset_info(codebase_version: str, fps: int, robot: Robot) -> dict: +def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict: + shapes = {key: len(names) for key, names in robot.names.items()} + camera_shapes = {} + for key, cam in robot.cameras.items(): + video_key = f"observation.images.{key}" + camera_shapes[video_key] = { + "width": cam.width, + "height": cam.height, + "channels": cam.channels, + } return { "codebase_version": codebase_version, "data_path": DEFAULT_PARQUET_PATH, @@ -159,12 +174,12 @@ def create_dataset_info(codebase_version: str, fps: int, robot: Robot) -> dict: "chunks_size": DEFAULT_CHUNK_SIZE, "fps": fps, "splits": {}, - # "keys": keys, - # "video_keys": video_keys, - # "image_keys": image_keys, - # "shapes": {**sequence_shapes, **video_shapes, **image_shapes}, - # "names": names, - # "videos": {"videos_path": DEFAULT_VIDEO_PATH} if video_keys else None, + "keys": list(robot.names), + "video_keys": list(camera_shapes) if use_videos else [], + "image_keys": [] if use_videos else list(camera_shapes), + "shapes": {**shapes, **camera_shapes}, + "names": robot.names, + "videos": {"videos_path": DEFAULT_VIDEO_PATH} if use_videos else None, } @@ -270,6 +285,7 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic return delta_indices +# TODO(aliberts): remove def load_previous_and_future_frames( item: dict[str, torch.Tensor], hf_dataset: datasets.Dataset, @@ -363,6 +379,7 @@ def load_previous_and_future_frames( return item +# TODO(aliberts): remove def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: """ Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. @@ -417,6 +434,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc return episode_data_index +# TODO(aliberts): remove def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: """Reset the `episode_index` of the provided HuggingFace Dataset. @@ -454,7 +472,7 @@ def cycle(iterable): iterator = iter(iterable) -def create_branch(repo_id, *, branch: str, repo_type: str | None = None): +def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None: """Create a branch on a existing Hugging Face repo. Delete the branch if it already exists before creating it. """ diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py index 2d8b12c9..d284cf55 100644 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -192,6 +192,7 @@ class OpenCVCameraConfig: width: int | None = None height: int | None = None color_mode: str = "rgb" + channels: int | None = None rotation: int | None = None mock: bool = False @@ -201,6 +202,8 @@ class OpenCVCameraConfig: f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." ) + self.channels = 3 + if self.rotation not in [-90, None, 90, 180]: raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") @@ -268,6 +271,7 @@ class OpenCVCamera: self.fps = config.fps self.width = config.width self.height = config.height + self.channels = config.channels self.color_mode = config.color_mode self.mock = config.mock diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 08bcec2e..6a8805dc 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -15,7 +15,8 @@ import torch import tqdm from termcolor import colored -from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer +from lerobot.common.datasets.image_writer import safe_stop_image_writer +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.utils import busy_wait @@ -227,7 +228,7 @@ def control_loop( control_time_s=None, teleoperate=False, display_cameras=False, - dataset=None, + dataset: LeRobotDataset | None = None, events=None, policy=None, device=None, @@ -268,7 +269,8 @@ def control_loop( action = {"action": action} if dataset is not None: - add_frame(dataset, observation, action) + frame = {**observation, **action} + dataset.add_frame(frame) if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 20969c30..6ee2cae7 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -349,6 +349,13 @@ class ManipulatorRobot: self.is_connected = False self.logs = {} + action_names = [f"{arm}_{motor}" for arm, bus in self.leader_arms.items() for motor in bus.motors] + state_names = [f"{arm}_{motor}" for arm, bus in self.follower_arms.items() for motor in bus.motors] + self.names = { + "action": action_names, + "observation.state": state_names, + } + @property def has_camera(self): return len(self.cameras) > 0 diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 425247e6..3d9073b0 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -105,11 +105,11 @@ from pathlib import Path from typing import List # from safetensors.torch import load_file, save_file +from lerobot.common.datasets.image_writer import ImageWriter from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.populate_dataset import ( create_lerobot_dataset, delete_current_episode, - init_dataset, save_current_episode, ) from lerobot.common.robot_devices.control_utils import ( @@ -233,16 +233,12 @@ def record( # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy) - dataset = init_dataset( - repo_id, - root, - force_override, - fps, - video, - write_images=robot.has_camera, + image_writer = ImageWriter( + write_dir=root, num_image_writer_processes=num_image_writer_processes, num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras, ) + dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer) if not robot.is_connected: robot.connect() @@ -260,8 +256,9 @@ def record( if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() + recorded_episodes = 0 while True: - if dataset["num_episodes"] >= num_episodes: + if recorded_episodes >= num_episodes: break episode_index = dataset["num_episodes"]