add delete_episode, WIP on consolidate
parent 299451af81
commit c4c0a43de7
@@ -107,6 +107,12 @@ class ImageWriter:
         )
         return str(self.dir / fpath) if return_str else self.dir / fpath
 
+    def get_episode_dir(self, episode_index: int, image_key: str, return_str: bool = True) -> str | Path:
+        dir_path = self.get_image_file_path(
+            episode_index=episode_index, image_key=image_key, frame_index=0, return_str=False
+        ).parent
+        return str(dir_path) if return_str else dir_path
+
     def stop(self, timeout=20) -> None:
         """Stop the image writer, waiting for all processes or threads to finish."""
         if self.type == "threads":
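Note on the hunk above: get_episode_dir finds the directory holding all frames of one episode for one camera by taking the parent of the frame-0 image path. A minimal standalone sketch of that idea, using only pathlib; the path template and function names below are illustrative assumptions, not the repository's actual ImageWriter API:

    from pathlib import Path

    # Hypothetical layout, for illustration only; the real ImageWriter builds
    # frame paths through get_image_file_path().
    TEMPLATE = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

    def episode_dir(images_root: Path, episode_index: int, image_key: str) -> Path:
        # The episode directory is simply the parent of any frame path, e.g. frame 0.
        frame0 = images_root / TEMPLATE.format(
            image_key=image_key, episode_index=episode_index, frame_index=0
        )
        return frame0.parent

    # episode_dir(Path("images"), 12, "observation.images.top")
    # -> images/observation.images.top/episode_000012

Working with a Path rather than a string keeps follow-up calls such as .is_dir() and shutil.rmtree() straightforward, which is how the new delete_episode below uses it.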
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 import os
+import shutil
 from pathlib import Path
 from typing import Callable
 
@@ -25,7 +26,7 @@ import torch.utils
 from datasets import load_dataset
 from huggingface_hub import snapshot_download, upload_folder
 
-from lerobot.common.datasets.compute_stats import aggregate_stats
+from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
     append_jsonl,
@@ -630,9 +631,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
 
     def delete_episode(self) -> None:
-        pass # TODO
+        episode_index = self.episode_buffer["episode_index"]
+        if self.image_writer is not None:
+            for cam_key in self.camera_keys:
+                cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
+                if cam_dir.is_dir():
+                    shutil.rmtree(cam_dir)
 
-    def consolidate(self) -> None:
+        # Reset the buffer
+        self.episode_buffer = self._create_episode_buffer()
+
+    def consolidate(self, run_compute_stats: bool = True) -> None:
+        if run_compute_stats:
+            logging.info("Computing dataset statistics")
+            self.hf_dataset = self.load_hf_dataset()
+            self.stats = compute_stats(self)
+            write_json()
         pass # TODO
         # Sanity checks:
         # - [ ] shapes
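Note on the hunk above: per the commit title, consolidate() is still work in progress here (the body falls through to pass # TODO, and write_json() is called without arguments). A hedged sketch of where the stats step seems to be headed, using only the standard library; the file location, the helper name, and the assumption that the stats are already plain Python values are mine, not the repository's final layout:

    import json
    from pathlib import Path

    def save_stats(stats: dict, root: Path) -> None:
        # Assumed destination; the actual metadata layout in LeRobot may differ.
        stats_path = root / "meta" / "stats.json"
        stats_path.parent.mkdir(parents=True, exist_ok=True)
        with open(stats_path, "w") as f:
            # Assumes stats holds only JSON-serializable values (lists/floats),
            # not tensors.
            json.dump(stats, f, indent=4)

    # Usage sketch: save_stats(dataset.stats, dataset.root) after compute_stats(dataset).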