Fix paths & add add_frame doc
parent 6c2cb6e107
commit 237a484be0
@@ -99,19 +99,16 @@ class ImageWriter:
         img = Image.fromarray(image.numpy())
         img.save(str(file_path), quality=100)
 
-    def get_image_file_path(
-        self, episode_index: int, image_key: str, frame_index: int, return_str: bool = True
-    ) -> str | Path:
+    def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
         fpath = self.image_path.format(
             image_key=image_key, episode_index=episode_index, frame_index=frame_index
         )
-        return str(self.dir / fpath) if return_str else self.dir / fpath
+        return self.dir / fpath
 
-    def get_episode_dir(self, episode_index: int, image_key: str, return_str: bool = True) -> str | Path:
-        dir_path = self.get_image_file_path(
-            episode_index=episode_index, image_key=image_key, frame_index=0, return_str=False
+    def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
+        return self.get_image_file_path(
+            episode_index=episode_index, image_key=image_key, frame_index=0
         ).parent
-        return str(dir_path) if return_str else dir_path
 
     def stop(self, timeout=20) -> None:
         """Stop the image writer, waiting for all processes or threads to finish."""
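
After this hunk, ImageWriter's path helpers return pathlib.Path objects and leave any string conversion to the caller. A minimal sketch of that call pattern, using an assumed image_path template and a standalone helper rather than the actual ImageWriter class:

from pathlib import Path

# Assumed template, modeled on self.image_path above; the real value may differ.
IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

def get_image_file_path(write_dir: Path, image_key: str, episode_index: int, frame_index: int) -> Path:
    # Return a Path and let the caller decide when a str is needed.
    fpath = IMAGE_PATH.format(image_key=image_key, episode_index=episode_index, frame_index=frame_index)
    return write_dir / fpath

img_path = get_image_file_path(Path("/tmp/imgs"), "observation.images.cam_high", episode_index=0, frame_index=0)
img_path.parent.mkdir(parents=True, exist_ok=True)  # callers create directories themselves
print(str(img_path))  # convert only at the point where a string is actually required
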
@@ -271,10 +271,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         files = None
         ignore_patterns = None if download_videos else "videos/"
         if self.episodes is not None:
-            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
+            files = [str(self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
             if len(self.video_keys) > 0 and download_videos:
                 video_files = [
-                    self.get_video_file_path(ep_idx, vid_key)
+                    str(self.get_video_file_path(ep_idx, vid_key))
                     for vid_key in self.video_keys
                     for ep_idx in self.episodes
                 ]
@@ -288,23 +288,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
             path = str(self.root / "data")
             hf_dataset = load_dataset("parquet", data_dir=path, split="train")
         else:
-            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
+            files = [str(self.root / self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
             hf_dataset = load_dataset("parquet", data_files=files, split="train")
 
         hf_dataset.set_transform(hf_transform_to_torch)
         return hf_dataset
 
-    def get_data_file_path(self, ep_index: int, return_str: bool = True) -> str | Path:
+    def get_data_file_path(self, ep_index: int) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.data_path.format(
+        return self.data_path.format(
             episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
         )
-        return str(self.root / fpath) if return_str else self.root / fpath
 
-    def get_video_file_path(self, ep_index: int, vid_key: str, return_str: bool = True) -> str | Path:
+    def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
-        return str(self.root / fpath) if return_str else self.root / fpath
+        return self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
 
     def get_episode_chunk(self, ep_index: int) -> int:
         ep_chunk = ep_index // self.chunks_size
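
With this hunk, get_data_file_path and get_video_file_path return paths relative to the dataset root, so call sites either prepend self.root or wrap the result in str(). A sketch of that contract with a simplified, assumed data_path template (the real template also interpolates total_episodes):

from pathlib import Path

DATA_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"  # simplified assumption
CHUNKS_SIZE = 1000  # assumed chunk size

def get_data_file_path(ep_index: int) -> Path:
    ep_chunk = ep_index // CHUNKS_SIZE
    return Path(DATA_PATH.format(episode_chunk=ep_chunk, episode_index=ep_index))

root = Path.home() / ".cache" / "my_dataset"
rel_path = get_data_file_path(1234)  # data/chunk-001/episode_001234.parquet (relative)
abs_path = root / rel_path           # join with the dataset root before touching disk
files = [str(abs_path)]              # e.g. the form load_dataset(data_files=...) expects
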
@@ -554,6 +552,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         }
 
     def add_frame(self, frame: dict) -> None:
+        """
+        This function only adds the frame to the episode_buffer. Apart from images — which are written in a
+        temporary directory — nothing is written to disk. To save those frames, the 'add_episode()' method
+        then needs to be called.
+        """
         frame_index = self.episode_buffer["size"]
         self.episode_buffer["frame_index"].append(frame_index)
         self.episode_buffer["timestamp"].append(frame_index / self.fps)
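
The new docstring describes the buffering contract: add_frame only appends to episode_buffer (images go to the image writer's temporary directory), and nothing is persisted until the episode is saved. A hypothetical recording loop illustrating that contract; the constructor call, method arguments, and feature keys below are placeholders, not the exact API:

import numpy as np
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Placeholder construction; the real factory and its arguments may differ.
dataset = LeRobotDataset.create("my_user/my_dataset", fps=30)

for _ in range(100):
    frame = {
        "observation.images.cam_high": np.zeros((480, 640, 3), dtype=np.uint8),
        "observation.state": np.zeros(6, dtype=np.float32),
        "action": np.zeros(6, dtype=np.float32),
    }
    dataset.add_frame(frame)  # buffered only; images land in a temporary directory

dataset.add_episode()  # per the docstring, this is the step that writes the episode to disk
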
@@ -571,10 +574,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Save images
         for cam_key in self.camera_keys:
             img_path = self.image_writer.get_image_file_path(
-                episode_index=self.episode_buffer["episode_index"],
-                image_key=cam_key,
-                frame_index=frame_index,
-                return_str=False,
+                episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
             )
             if frame_index == 0:
                 img_path.parent.mkdir(parents=True, exist_ok=True)
@@ -632,7 +632,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         features = self.features
         ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
         ep_table = ep_dataset._data.table
-        ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False)
+        ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
         pq.write_table(ep_table, ep_data_path)
 
@@ -671,7 +671,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         episode_index = self.episode_buffer["episode_index"]
         if self.image_writer is not None:
             for cam_key in self.camera_keys:
-                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
+                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
                 if img_dir.is_dir():
                     shutil.rmtree(img_dir)
 
@@ -686,7 +686,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
             current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
             current_file_name = list(self.root.glob(current_file_name))[0]
-            updated_file_name = self.get_data_file_path(ep_idx)
+            updated_file_name = self.root / self.get_data_file_path(ep_idx)
            current_file_name.rename(updated_file_name)
 
     def _remove_image_writer(self) -> None:
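
The glob in this hunk exists because the parquet file name embeds total_episodes, which changes as episodes are added: the old file is located with a wildcard and renamed to the newly formatted name. A small sketch of that rename step, with an assumed data_path template:

from pathlib import Path

# Assumed template; the real self.data_path may differ, but it contains {total_episodes:05d}.
DATA_PATH = "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"

def rename_for_new_total(root: Path, ep_idx: int, ep_chunk: int, new_total: int) -> None:
    # The old name was written with a previous total_episodes value, so match it with a wildcard.
    pattern = DATA_PATH.replace("{total_episodes:05d}", "*").format(episode_chunk=ep_chunk, episode_index=ep_idx)
    current_file = next(root.glob(pattern))
    updated_file = root / DATA_PATH.format(episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=new_total)
    current_file.rename(updated_file)
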
@@ -700,7 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
             # to call self.image_writer here
             tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
-            video_path = self.get_video_file_path(episode_index, key, return_str=False)
+            video_path = self.root / self.get_video_file_path(episode_index, key)
             if video_path.is_file():
                 # Skip if video is already encoded. Could be the case when resuming data recording.
                 continue