Fix paths & add add_frame doc

This commit is contained in:
Simon Alibert 2024-10-22 22:46:34 +02:00
parent 6c2cb6e107
commit 237a484be0
2 changed files with 22 additions and 25 deletions

View File

@ -99,19 +99,16 @@ class ImageWriter:
img = Image.fromarray(image.numpy()) img = Image.fromarray(image.numpy())
img.save(str(file_path), quality=100) img.save(str(file_path), quality=100)
def get_image_file_path( def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
self, episode_index: int, image_key: str, frame_index: int, return_str: bool = True
) -> str | Path:
fpath = self.image_path.format( fpath = self.image_path.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index image_key=image_key, episode_index=episode_index, frame_index=frame_index
) )
return str(self.dir / fpath) if return_str else self.dir / fpath return self.dir / fpath
def get_episode_dir(self, episode_index: int, image_key: str, return_str: bool = True) -> str | Path: def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
dir_path = self.get_image_file_path( return self.get_image_file_path(
episode_index=episode_index, image_key=image_key, frame_index=0, return_str=False episode_index=episode_index, image_key=image_key, frame_index=0
).parent ).parent
return str(dir_path) if return_str else dir_path
def stop(self, timeout=20) -> None: def stop(self, timeout=20) -> None:
"""Stop the image writer, waiting for all processes or threads to finish.""" """Stop the image writer, waiting for all processes or threads to finish."""

View File

@ -271,10 +271,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
files = None files = None
ignore_patterns = None if download_videos else "videos/" ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None: if self.episodes is not None:
files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes] files = [str(self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
if len(self.video_keys) > 0 and download_videos: if len(self.video_keys) > 0 and download_videos:
video_files = [ video_files = [
self.get_video_file_path(ep_idx, vid_key) str(self.get_video_file_path(ep_idx, vid_key))
for vid_key in self.video_keys for vid_key in self.video_keys
for ep_idx in self.episodes for ep_idx in self.episodes
] ]
@ -288,23 +288,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
path = str(self.root / "data") path = str(self.root / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split="train") hf_dataset = load_dataset("parquet", data_dir=path, split="train")
else: else:
files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes] files = [str(self.root / self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
hf_dataset = load_dataset("parquet", data_files=files, split="train") hf_dataset = load_dataset("parquet", data_files=files, split="train")
hf_dataset.set_transform(hf_transform_to_torch) hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset return hf_dataset
def get_data_file_path(self, ep_index: int, return_str: bool = True) -> str | Path: def get_data_file_path(self, ep_index: int) -> Path:
ep_chunk = self.get_episode_chunk(ep_index) ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.data_path.format( return self.data_path.format(
episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
) )
return str(self.root / fpath) if return_str else self.root / fpath
def get_video_file_path(self, ep_index: int, vid_key: str, return_str: bool = True) -> str | Path: def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index) ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) return self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return str(self.root / fpath) if return_str else self.root / fpath
def get_episode_chunk(self, ep_index: int) -> int: def get_episode_chunk(self, ep_index: int) -> int:
ep_chunk = ep_index // self.chunks_size ep_chunk = ep_index // self.chunks_size
@ -554,6 +552,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
} }
def add_frame(self, frame: dict) -> None: def add_frame(self, frame: dict) -> None:
"""
This function only adds the frame to the episode_buffer. Apart from images, which are written to a
temporary directory, nothing is written to disk. To save those frames, the 'add_episode()' method
must then be called.
"""
frame_index = self.episode_buffer["size"] frame_index = self.episode_buffer["size"]
self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(frame_index / self.fps) self.episode_buffer["timestamp"].append(frame_index / self.fps)
@ -571,10 +574,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Save images # Save images
for cam_key in self.camera_keys: for cam_key in self.camera_keys:
img_path = self.image_writer.get_image_file_path( img_path = self.image_writer.get_image_file_path(
episode_index=self.episode_buffer["episode_index"], episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
image_key=cam_key,
frame_index=frame_index,
return_str=False,
) )
if frame_index == 0: if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True) img_path.parent.mkdir(parents=True, exist_ok=True)
@ -632,7 +632,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
features = self.features features = self.features
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train") ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
ep_table = ep_dataset._data.table ep_table = ep_dataset._data.table
ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False) ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True) ep_data_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(ep_table, ep_data_path) pq.write_table(ep_table, ep_data_path)
@ -671,7 +671,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index = self.episode_buffer["episode_index"] episode_index = self.episode_buffer["episode_index"]
if self.image_writer is not None: if self.image_writer is not None:
for cam_key in self.camera_keys: for cam_key in self.camera_keys:
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False) img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
if img_dir.is_dir(): if img_dir.is_dir():
shutil.rmtree(img_dir) shutil.rmtree(img_dir)
@ -686,7 +686,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
current_file_name = self.data_path.replace("{total_episodes:05d}", "*") current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx) current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
current_file_name = list(self.root.glob(current_file_name))[0] current_file_name = list(self.root.glob(current_file_name))[0]
updated_file_name = self.get_data_file_path(ep_idx) updated_file_name = self.root / self.get_data_file_path(ep_idx)
current_file_name.rename(updated_file_name) current_file_name.rename(updated_file_name)
def _remove_image_writer(self) -> None: def _remove_image_writer(self) -> None:
@ -700,7 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
# to call self.image_writer here # to call self.image_writer here
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key) tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
video_path = self.get_video_file_path(episode_index, key, return_str=False) video_path = self.root / self.get_video_file_path(episode_index, key)
if video_path.is_file(): if video_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording. # Skip if video is already encoded. Could be the case when resuming data recording.
continue continue