From c4c0a43de76c61c118bf07a4b52b605abf883fd3 Mon Sep 17 00:00:00 2001
From: Simon Alibert <simon.alibert@huggingface.co>
Date: Mon, 21 Oct 2024 20:10:13 +0200
Subject: [PATCH] add delete_episode, WIP on consolidate

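Add LeRobotDataset.delete_episode(): it discards the episode currently
held in the episode buffer by removing the images already written to
disk for that episode (through the new ImageWriter.get_episode_dir()
helper) and re-creating the buffer. Also start wiring up consolidate():
when run_compute_stats is set, the hf_dataset is reloaded and dataset
statistics are recomputed (still WIP).

Intended usage, as a sketch (the surrounding recording loop is
illustrative; only delete_episode()/consolidate() come from this patch):

    dataset.add_frame(frame)                     # existing recording API
    if episode_went_wrong:
        dataset.delete_episode()                 # drop written images, reset buffer
    ...
    dataset.consolidate(run_compute_stats=True)  # recompute stats once recording is done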
---
 lerobot/common/datasets/image_writer.py    |  6 ++++++
 lerobot/common/datasets/lerobot_dataset.py | 20 +++++++++++++++++---
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py
index c87e342b..7bdefc64 100644
--- a/lerobot/common/datasets/image_writer.py
+++ b/lerobot/common/datasets/image_writer.py
@@ -107,6 +107,12 @@ class ImageWriter:
         )
         return str(self.dir / fpath) if return_str else self.dir / fpath
 
+    def get_episode_dir(self, episode_index: int, image_key: str, return_str: bool = True) -> str | Path:
+        dir_path = self.get_image_file_path(
+            episode_index=episode_index, image_key=image_key, frame_index=0, return_str=False
+        ).parent
+        return str(dir_path) if return_str else dir_path
+
     def stop(self, timeout=20) -> None:
         """Stop the image writer, waiting for all processes or threads to finish."""
         if self.type == "threads":
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 53b3c4af..6d68946e 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 import os
+import shutil
 from pathlib import Path
 from typing import Callable
 
@@ -25,7 +26,7 @@ import torch.utils
 from datasets import load_dataset
 from huggingface_hub import snapshot_download, upload_folder
 
-from lerobot.common.datasets.compute_stats import aggregate_stats
+from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
     append_jsonl,
@@ -630,9 +631,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
 
     def delete_episode(self) -> None:
-        pass  # TODO
+        episode_index = self.episode_buffer["episode_index"]
+        if self.image_writer is not None:
+            for cam_key in self.camera_keys:
+                cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
+                if cam_dir.is_dir():
+                    shutil.rmtree(cam_dir)
 
-    def consolidate(self) -> None:
+        # Reset the buffer
+        self.episode_buffer = self._create_episode_buffer()
+
+    def consolidate(self, run_compute_stats: bool = True) -> None:
+        if run_compute_stats:
+            logging.info("Computing dataset statistics")
+            self.hf_dataset = self.load_hf_dataset()
+            self.stats = compute_stats(self)
+            write_json(self.stats, self.root / "meta/stats.json")  # TODO: serialize stats (tensors -> lists) before writing
         pass  # TODO
         # Sanity checks:
         # - [ ] shapes