From 237a484be0704160ca32ade58eb07b2eed0db5fb Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Tue, 22 Oct 2024 22:46:34 +0200 Subject: [PATCH] Fix paths & add add_frame doc --- lerobot/common/datasets/image_writer.py | 13 ++++----- lerobot/common/datasets/lerobot_dataset.py | 34 +++++++++++----------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index b86a7cdf..09f803e2 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -99,19 +99,16 @@ class ImageWriter: img = Image.fromarray(image.numpy()) img.save(str(file_path), quality=100) - def get_image_file_path( - self, episode_index: int, image_key: str, frame_index: int, return_str: bool = True - ) -> str | Path: + def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: fpath = self.image_path.format( image_key=image_key, episode_index=episode_index, frame_index=frame_index ) - return str(self.dir / fpath) if return_str else self.dir / fpath + return self.dir / fpath - def get_episode_dir(self, episode_index: int, image_key: str, return_str: bool = True) -> str | Path: - dir_path = self.get_image_file_path( - episode_index=episode_index, image_key=image_key, frame_index=0, return_str=False + def get_episode_dir(self, episode_index: int, image_key: str) -> Path: + return self.get_image_file_path( + episode_index=episode_index, image_key=image_key, frame_index=0 ).parent - return str(dir_path) if return_str else dir_path def stop(self, timeout=20) -> None: """Stop the image writer, waiting for all processes or threads to finish.""" diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index ad5a37cf..1f01d9f0 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -271,10 +271,10 @@ class LeRobotDataset(torch.utils.data.Dataset): files = None ignore_patterns = None if download_videos else "videos/" if self.episodes is not None: - files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes] + files = [str(self.get_data_file_path(ep_idx)) for ep_idx in self.episodes] if len(self.video_keys) > 0 and download_videos: video_files = [ - self.get_video_file_path(ep_idx, vid_key) + str(self.get_video_file_path(ep_idx, vid_key)) for vid_key in self.video_keys for ep_idx in self.episodes ] @@ -288,23 +288,21 @@ class LeRobotDataset(torch.utils.data.Dataset): path = str(self.root / "data") hf_dataset = load_dataset("parquet", data_dir=path, split="train") else: - files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes] + files = [str(self.root / self.get_data_file_path(ep_idx)) for ep_idx in self.episodes] hf_dataset = load_dataset("parquet", data_files=files, split="train") hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset - def get_data_file_path(self, ep_index: int, return_str: bool = True) -> str | Path: + def get_data_file_path(self, ep_index: int) -> Path: ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.data_path.format( + return self.data_path.format( episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes ) - return str(self.root / fpath) if return_str else self.root / fpath - def get_video_file_path(self, ep_index: int, vid_key: str, return_str: bool = True) -> str | Path: + def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) - return str(self.root / fpath) if return_str else self.root / fpath + return self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) def get_episode_chunk(self, ep_index: int) -> int: ep_chunk = ep_index // self.chunks_size @@ -554,6 +552,11 @@ class LeRobotDataset(torch.utils.data.Dataset): } def add_frame(self, frame: dict) -> None: + """ + This function only adds the frame to the episode_buffer. Apart from images — which are written in a + temporary directory — nothing is written to disk. To save those frames, the 'add_episode()' method + then needs to be called. + """ frame_index = self.episode_buffer["size"] self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(frame_index / self.fps) @@ -571,10 +574,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # Save images for cam_key in self.camera_keys: img_path = self.image_writer.get_image_file_path( - episode_index=self.episode_buffer["episode_index"], - image_key=cam_key, - frame_index=frame_index, - return_str=False, + episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index ) if frame_index == 0: img_path.parent.mkdir(parents=True, exist_ok=True) @@ -632,7 +632,7 @@ class LeRobotDataset(torch.utils.data.Dataset): features = self.features ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train") ep_table = ep_dataset._data.table - ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False) + ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index) ep_data_path.parent.mkdir(parents=True, exist_ok=True) pq.write_table(ep_table, ep_data_path) @@ -671,7 +671,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index = self.episode_buffer["episode_index"] if self.image_writer is not None: for cam_key in self.camera_keys: - img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False) + img_dir = self.image_writer.get_episode_dir(episode_index, cam_key) if img_dir.is_dir(): shutil.rmtree(img_dir) @@ -686,7 +686,7 @@ class LeRobotDataset(torch.utils.data.Dataset): current_file_name = self.data_path.replace("{total_episodes:05d}", "*") current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx) current_file_name = list(self.root.glob(current_file_name))[0] - updated_file_name = self.get_data_file_path(ep_idx) + updated_file_name = self.root / self.get_data_file_path(ep_idx) current_file_name.rename(updated_file_name) def _remove_image_writer(self) -> None: @@ -700,7 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need # to call self.image_writer here tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key) - video_path = self.get_video_file_path(episode_index, key, return_str=False) + video_path = self.root / self.get_video_file_path(episode_index, key) if video_path.is_file(): # Skip if video is already encoded. Could be the case when resuming data recording. continue