Remove total_episodes from default parquet path
This commit is contained in:
parent
237a484be0
commit
c72dc23c43
|
@ -296,13 +296,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
def get_data_file_path(self, ep_index: int) -> Path:
|
def get_data_file_path(self, ep_index: int) -> Path:
|
||||||
ep_chunk = self.get_episode_chunk(ep_index)
|
ep_chunk = self.get_episode_chunk(ep_index)
|
||||||
return self.data_path.format(
|
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
||||||
episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
|
return Path(fpath)
|
||||||
)
|
|
||||||
|
|
||||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||||
ep_chunk = self.get_episode_chunk(ep_index)
|
ep_chunk = self.get_episode_chunk(ep_index)
|
||||||
return self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||||
|
return Path(fpath)
|
||||||
|
|
||||||
def get_episode_chunk(self, ep_index: int) -> int:
|
def get_episode_chunk(self, ep_index: int) -> int:
|
||||||
ep_chunk = ep_index // self.chunks_size
|
ep_chunk = ep_index // self.chunks_size
|
||||||
|
@ -678,17 +678,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
# Reset the buffer
|
# Reset the buffer
|
||||||
self.episode_buffer = self._create_episode_buffer()
|
self.episode_buffer = self._create_episode_buffer()
|
||||||
|
|
||||||
def _update_data_file_names(self) -> None:
|
|
||||||
# TODO(aliberts): remove the need for this hack by removing total_episodes part in data file names.
|
|
||||||
# Must first investigate if this doesn't break hub/datasets features like viewer etc.
|
|
||||||
for ep_idx in range(self.total_episodes):
|
|
||||||
ep_chunk = self.get_episode_chunk(ep_idx)
|
|
||||||
current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
|
|
||||||
current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
|
||||||
current_file_name = list(self.root.glob(current_file_name))[0]
|
|
||||||
updated_file_name = self.root / self.get_data_file_path(ep_idx)
|
|
||||||
current_file_name.rename(updated_file_name)
|
|
||||||
|
|
||||||
def _remove_image_writer(self) -> None:
|
def _remove_image_writer(self) -> None:
|
||||||
if self.image_writer is not None:
|
if self.image_writer is not None:
|
||||||
self.image_writer = None
|
self.image_writer = None
|
||||||
|
@ -710,7 +699,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
shutil.rmtree(tmp_imgs_dir)
|
shutil.rmtree(tmp_imgs_dir)
|
||||||
|
|
||||||
def consolidate(self, run_compute_stats: bool = True) -> None:
|
def consolidate(self, run_compute_stats: bool = True) -> None:
|
||||||
self._update_data_file_names()
|
|
||||||
self.hf_dataset = self.load_hf_dataset()
|
self.hf_dataset = self.load_hf_dataset()
|
||||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||||
|
|
|
@ -37,9 +37,8 @@ STATS_PATH = "meta/stats.json"
|
||||||
TASKS_PATH = "meta/tasks.jsonl"
|
TASKS_PATH = "meta/tasks.jsonl"
|
||||||
|
|
||||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||||
DEFAULT_PARQUET_PATH = (
|
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||||
"data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
|
|
||||||
)
|
|
||||||
DATASET_CARD_TEMPLATE = """
|
DATASET_CARD_TEMPLATE = """
|
||||||
---
|
---
|
||||||
# Metadata will go there
|
# Metadata will go there
|
||||||
|
@ -88,6 +87,7 @@ def write_json(data: dict, fpath: Path) -> None:
|
||||||
|
|
||||||
|
|
||||||
def append_jsonl(data: dict, fpath: Path) -> None:
|
def append_jsonl(data: dict, fpath: Path) -> None:
|
||||||
|
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with jsonlines.open(fpath, "a") as writer:
|
with jsonlines.open(fpath, "a") as writer:
|
||||||
writer.write(data)
|
writer.write(data)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue