Remove total_episodes from default parquet path
commit c72dc23c43
parent 237a484be0
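Summary: the episode parquet file name previously embedded the dataset's total episode count ({total_episodes}), so recording a new episode invalidated every existing file name and consolidate() had to rename them all. This commit drops total_episodes from DEFAULT_PARQUET_PATH, has get_data_file_path/get_video_file_path return proper Path objects, and removes the _update_data_file_names() rename hack.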
@@ -296,13 +296,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def get_data_file_path(self, ep_index: int) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
-        return self.data_path.format(
-            episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
-        )
+        fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
+        return Path(fpath)
 
     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
         ep_chunk = self.get_episode_chunk(ep_index)
-        return self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+        fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+        return Path(fpath)
 
     def get_episode_chunk(self, ep_index: int) -> int:
         ep_chunk = ep_index // self.chunks_size
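As a quick illustration of the new behavior, a minimal standalone sketch of the updated path resolution. The template matches the new DEFAULT_PARQUET_PATH below; the chunk size of 1000 is an assumed value for illustration, not taken from this commit.

from pathlib import Path

# Template as in the new DEFAULT_PARQUET_PATH; chunk size is an assumption.
data_path = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
chunks_size = 1000

def get_data_file_path(ep_index: int) -> Path:
    # Episode 1042 falls into chunk 1042 // 1000 == 1.
    ep_chunk = ep_index // chunks_size
    fpath = data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
    return Path(fpath)

print(get_data_file_path(1042))  # data/chunk-001/episode_001042.parquet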
@@ -678,17 +678,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
 
-    def _update_data_file_names(self) -> None:
-        # TODO(aliberts): remove the need for this hack by removing total_episodes part in data file names.
-        # Must first investigate if this doesn't break hub/datasets features like viewer etc.
-        for ep_idx in range(self.total_episodes):
-            ep_chunk = self.get_episode_chunk(ep_idx)
-            current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
-            current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
-            current_file_name = list(self.root.glob(current_file_name))[0]
-            updated_file_name = self.root / self.get_data_file_path(ep_idx)
-            current_file_name.rename(updated_file_name)
-
     def _remove_image_writer(self) -> None:
         if self.image_writer is not None:
             self.image_writer = None
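With total_episodes gone from the file-name template (see the constants change below), an episode's parquet file name is fixed at write time, so the glob-and-rename pass above has nothing left to do and can be deleted along with its call site in consolidate().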
@@ -710,7 +699,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
             shutil.rmtree(tmp_imgs_dir)
 
     def consolidate(self, run_compute_stats: bool = True) -> None:
-        self._update_data_file_names()
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
         check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -37,9 +37,8 @@ STATS_PATH = "meta/stats.json"
 TASKS_PATH = "meta/tasks.jsonl"
 
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
-DEFAULT_PARQUET_PATH = (
-    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
-)
+DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
 
 DATASET_CARD_TEMPLATE = """
 ---
 # Metadata will go there
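For comparison, a small sketch of how the two templates resolve for the same episode (episode 3, chunk 0, in a dataset of 10 episodes; values chosen for illustration):

old = "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
new = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"

print(old.format(episode_chunk=0, episode_index=3, total_episodes=10))
# data/chunk-000/train-00003-of-00010.parquet  -> renamed whenever the total grows
print(new.format(episode_chunk=0, episode_index=3))
# data/chunk-000/episode_000003.parquet        -> stable once written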
@@ -88,6 +87,7 @@ def write_json(data: dict, fpath: Path) -> None:
 
 
 def append_jsonl(data: dict, fpath: Path) -> None:
     fpath.parent.mkdir(exist_ok=True, parents=True)
     with jsonlines.open(fpath, "a") as writer:
         writer.write(data)
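A usage sketch for append_jsonl, assuming the jsonlines package is installed; the target file reuses TASKS_PATH from the constants above, and the record fields are illustrative assumptions, not from this commit.

from pathlib import Path
import jsonlines

def append_jsonl(data: dict, fpath: Path) -> None:
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "a") as writer:
        writer.write(data)

# Field names below are assumptions for illustration.
append_jsonl({"task_index": 0, "task": "Pick up the cube."}, Path("meta/tasks.jsonl"))
append_jsonl({"task_index": 1, "task": "Open the drawer."}, Path("meta/tasks.jsonl"))
# meta/tasks.jsonl now holds one JSON object per line, in append order.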