Remove total_episodes from default parquet path

This commit is contained in:
Simon Alibert 2024-10-23 00:03:30 +02:00
parent 237a484be0
commit c72dc23c43
2 changed files with 7 additions and 19 deletions

View File

@ -296,13 +296,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
def get_data_file_path(self, ep_index: int) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
return self.data_path.format(
episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
)
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
return self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
ep_chunk = ep_index // self.chunks_size
@ -678,17 +678,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
def _update_data_file_names(self) -> None:
# TODO(aliberts): remove the need for this hack by removing total_episodes part in data file names.
# Must first investigate if this doesn't break hub/datasets features like viewer etc.
for ep_idx in range(self.total_episodes):
ep_chunk = self.get_episode_chunk(ep_idx)
current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
current_file_name = list(self.root.glob(current_file_name))[0]
updated_file_name = self.root / self.get_data_file_path(ep_idx)
current_file_name.rename(updated_file_name)
def _remove_image_writer(self) -> None:
if self.image_writer is not None:
self.image_writer = None
@ -710,7 +699,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
shutil.rmtree(tmp_imgs_dir)
def consolidate(self, run_compute_stats: bool = True) -> None:
self._update_data_file_names()
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)

View File

@ -37,9 +37,8 @@ STATS_PATH = "meta/stats.json"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
@ -88,6 +87,7 @@ def write_json(data: dict, fpath: Path) -> None:
def append_jsonl(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "a") as writer:
writer.write(data)