diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py
index c87e342b..7bdefc64 100644
--- a/lerobot/common/datasets/image_writer.py
+++ b/lerobot/common/datasets/image_writer.py
@@ -107,6 +107,12 @@ class ImageWriter:
         )
         return str(self.dir / fpath) if return_str else self.dir / fpath
 
+    def get_episode_dir(self, episode_index: int, image_key: str, return_str: bool = True) -> str | Path:
+        dir_path = self.get_image_file_path(
+            episode_index=episode_index, image_key=image_key, frame_index=0, return_str=False
+        ).parent
+        return str(dir_path) if return_str else dir_path
+
     def stop(self, timeout=20) -> None:
         """Stop the image writer, waiting for all processes or threads to finish."""
         if self.type == "threads":
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 53b3c4af..6d68946e 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
 import os
+import shutil
 from pathlib import Path
 from typing import Callable
@@ -25,7 +26,7 @@ import torch.utils
 from datasets import load_dataset
 from huggingface_hub import snapshot_download, upload_folder
 
-from lerobot.common.datasets.compute_stats import aggregate_stats
+from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
     append_jsonl,
@@ -630,9 +631,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
         append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
 
     def delete_episode(self) -> None:
-        pass # TODO
+        episode_index = self.episode_buffer["episode_index"]
+        if self.image_writer is not None:
+            for cam_key in self.camera_keys:
+                # Request a Path (not a str) so that .is_dir() works below.
+                cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
+                if cam_dir.is_dir():
+                    shutil.rmtree(cam_dir)
 
-    def consolidate(self) -> None:
+        # Reset the buffer
+        self.episode_buffer = self._create_episode_buffer()
+
+    def consolidate(self, run_compute_stats: bool = True) -> None:
+        if run_compute_stats:
+            logging.info("Computing dataset statistics")
+            self.hf_dataset = self.load_hf_dataset()
+            self.stats = compute_stats(self)
+            write_json(self.stats, self.root / "meta/stats.json")  # TODO: tensors may need serializing first
         pass # TODO
         # Sanity checks:
         # - [ ] shapes
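For context, a minimal sketch of how the new methods might be exercised in a recording loop. Only `delete_episode` and `consolidate(run_compute_stats=...)` come from this diff; `add_frame`, `save_episode`, and the helpers `record_frame`, `episode_done`, and `episode_ok` are assumed/placeholder names for illustration:

```python
# Hypothetical recording loop (sketch, not the actual recording script).
for _ in range(num_episodes):
    while not episode_done():  # placeholder termination check
        # Buffers frame data; images are written asynchronously by ImageWriter.
        dataset.add_frame(record_frame())

    if episode_ok():  # placeholder quality check
        dataset.save_episode(task)
    else:
        # Discard the buffered episode: removes the per-camera image
        # directories on disk (via get_episode_dir) and resets the buffer.
        dataset.delete_episode()

# After recording, reload the on-disk data and compute dataset statistics.
dataset.consolidate(run_compute_stats=True)
```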