From e991a310614a533284abc64d0f2e6f49682d8809 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Tue, 22 Oct 2024 00:19:25 +0200
Subject: [PATCH] Improve consistency between __init__() and create(), WIP on
 consolidate

---
 lerobot/common/datasets/lerobot_dataset.py | 69 ++++++++++++++--------
 lerobot/scripts/control_robot.py           | 20 ++++---
 2 files changed, 55 insertions(+), 34 deletions(-)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 6d68946e..ffbcf0fb 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import logging
 import os
 import shutil
@@ -39,6 +40,7 @@ from lerobot.common.datasets.utils import (
     get_hub_safe_version,
     hf_transform_to_torch,
     load_metadata,
+    unflatten_dict,
     write_json,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@@ -163,9 +165,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.image_writer = image_writer
-        self.episode_buffer = {}
-        self.consolidated = True
         self.delta_indices = None
+        self.consolidated = True
+        self.episode_buffer = {}
 
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
@@ -501,17 +503,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(\n"
+            f"{self.__class__.__name__}\n"
             f"  Repository ID: '{self.repo_id}',\n"
-            f"  Number of Samples: {self.num_samples},\n"
-            f"  Number of Episodes: {self.num_episodes},\n"
-            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
-            f"  Recorded Frames per Second: {self.fps},\n"
-            f"  Camera Keys: {self.camera_keys},\n"
-            f"  Video Frame Keys: {self.camera_keys if self.video else 'N/A'},\n"
-            f"  Transformations: {self.image_transforms},\n"
-            f"  Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
-            f")"
+            f"  Selected episodes: {self.episodes},\n"
+            f"  Number of selected episodes: {self.num_episodes},\n"
+            f"  Number of selected samples: {self.num_samples},\n"
+            f"\n{json.dumps(self.info, indent=4)}\n"
         )
 
     def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
@@ -563,12 +560,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading
         to the hub.
 
-        Use encode_videos if you want to encode videos during the saving of each episode. Otherwise,
-        you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend
+        Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
+        you can do it later with dataset.consolidate(). This gives more flexibility on when to spend
         time for video encoding.
""" episode_length = self.episode_buffer.pop("size") episode_index = self.episode_buffer["episode_index"] + if episode_index != self.total_episodes: + # TODO(aliberts): Add option to use existing episode_index + raise NotImplementedError() + task_index = self.get_task_index(task) self.episode_buffer["next.done"][-1] = True @@ -641,12 +642,30 @@ class LeRobotDataset(torch.utils.data.Dataset): # Reset the buffer self.episode_buffer = self._create_episode_buffer() + def _update_data_file_names(self) -> None: + # TODO(aliberts): remove the need for this hack by removing total_episodes part in data file names. + # Must first investigate if this doesn't break hub/datasets features like viewer etc. + for ep_idx in range(self.total_episodes): + ep_chunk = self.get_episode_chunk(ep_idx) + current_file_name = self.data_path.replace("{total_episodes:05d}", "*") + current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx) + current_file_name = list(self.root.glob(current_file_name))[0] + updated_file_name = self.get_data_file_path(ep_idx) + current_file_name.rename(updated_file_name) + def consolidate(self, run_compute_stats: bool = True) -> None: + self._update_data_file_names() if run_compute_stats: logging.info("Computing dataset statistics") self.hf_dataset = self.load_hf_dataset() self.stats = compute_stats(self) - write_json() + serialized_stats = {key: value.tolist() for key, value in self.stats.items()} + serialized_stats = unflatten_dict(serialized_stats) + write_json(serialized_stats, self.root / "meta/stats.json") + else: + logging.warning("Skipping computation of the dataset statistics.") + + self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts) pass # TODO # Sanity checks: # - [ ] shapes @@ -666,6 +685,7 @@ class LeRobotDataset(torch.utils.data.Dataset): tolerance_s: float = 1e-4, image_writer: ImageWriter | None = None, use_videos: bool = True, + video_backend: str | None = None, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" obj = cls.__new__(cls) @@ -674,15 +694,14 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._version = CODEBASE_VERSION obj.tolerance_s = tolerance_s obj.image_writer = image_writer - obj.hf_dataset = None if not all(cam.fps == fps for cam in robot.cameras.values()): - logging.warn( + logging.warning( f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." "In this case, frames from lower fps cameras will be repeated to fill in the blanks" ) - obj.tasks = {} + obj.tasks, obj.stats, obj.episode_dicts = {}, {}, [] obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos) write_json(obj.info, obj.root / "meta/info.json") @@ -694,14 +713,12 @@ class LeRobotDataset(torch.utils.data.Dataset): # In order to be able to push the dataset to the hub, it needs to be consolidation first. 
         obj.consolidated = True
-        # obj.episodes = None
-        # obj.image_transforms = None
-        # obj.delta_timestamps = None
-        # obj.episode_data_index = episode_data_index
-        # obj.stats = stats
-        # obj.info = info if info is not None else {}
-        # obj.videos_dir = videos_dir
-        # obj.video_backend = video_backend if video_backend is not None else "pyav"
+        obj.episodes = None
+        obj.hf_dataset = None
+        obj.image_transforms = None
+        obj.delta_timestamps = None
+        obj.episode_data_index = None
+        obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
 
 
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 86233251..62d6760b 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -107,9 +107,6 @@ from typing import List
 
 # from safetensors.torch import load_file, save_file
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.populate_dataset import (
-    create_lerobot_dataset,
-)
 from lerobot.common.robot_devices.control_utils import (
     control_loop,
     has_method,
@@ -210,7 +207,7 @@ def record(
     force_override=False,
     display_cameras=True,
     play_sounds=True,
-):
+) -> LeRobotDataset:
     # TODO(rcadene): Add option to record logs
     listener = None
     events = None
@@ -242,7 +239,7 @@ def record(
             num_processes=num_image_writer_processes,
             num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
         )
-        dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
+        dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)
 
     if not robot.is_connected:
         robot.connect()
@@ -301,8 +298,8 @@ def record(
             dataset.delete_episode()
             continue
 
-        # Increment by one dataset["current_episode_index"]
         dataset.add_episode(task)
+        recorded_episodes += 1
 
         if events["stop_recording"]:
             break
@@ -310,10 +307,17 @@ def record(
     log_say("Stop recording", play_sounds, blocking=True)
     stop_recording(robot, listener, display_cameras)
 
-    lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
+    logging.info("Waiting for image writer to terminate...")
+    dataset.image_writer.stop()
+
+    dataset.consolidate(run_compute_stats)
+
+    # lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
+    if push_to_hub:
+        dataset.push_to_repo()
 
     log_say("Exiting", play_sounds)
-    return lerobot_dataset
+    return dataset
 
 
 @safe_disconnect
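
Note for reviewers: below is a minimal sketch of the recording lifecycle this patch moves
toward, pieced together from the diff above. The `robot` object and the frame-recording step
are placeholders (in practice control_loop() fills the episode buffer during teleoperation),
and push_to_repo() is the method name as it appears in this WIP diff, so it may still change:

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Assumes an already-configured and connected `robot` with cameras,
    # as set up in lerobot/scripts/control_robot.py.
    dataset = LeRobotDataset.create("user/my_dataset", fps=30, robot=robot)

    for _ in range(2):  # record two episodes
        # ... fill dataset.episode_buffer frame by frame during teleoperation ...
        dataset.add_episode(task="dummy task")  # writes the episode buffer to disk

    # Saving episodes leaves dataset.consolidated == False (stats and file names
    # are not final yet), so consolidate before pushing to the hub.
    dataset.consolidate(run_compute_stats=True)
    dataset.push_to_repo()

The design intent, as reflected in consolidate(), is to let episode saving stay cheap during
recording and defer the expensive work (stats computation, file renaming, and eventually video
encoding) to a single consolidation step.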