Improve consistency between __init__() and create(), WIP on consolidate

Simon Alibert 2024-10-22 00:19:25 +02:00
parent c4c0a43de7
commit e991a31061
2 changed files with 55 additions and 34 deletions


@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import logging
 import os
 import shutil
@@ -39,6 +40,7 @@ from lerobot.common.datasets.utils import (
     get_hub_safe_version,
     hf_transform_to_torch,
     load_metadata,
+    unflatten_dict,
     write_json,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@@ -163,9 +165,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.image_writer = image_writer
-        self.episode_buffer = {}
-        self.consolidated = True
         self.delta_indices = None
+        self.consolidated = True
+        self.episode_buffer = {}

         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
@@ -501,17 +503,12 @@
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(\n"
+            f"{self.__class__.__name__}\n"
             f"  Repository ID: '{self.repo_id}',\n"
-            f"  Number of Samples: {self.num_samples},\n"
-            f"  Number of Episodes: {self.num_episodes},\n"
-            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
-            f"  Recorded Frames per Second: {self.fps},\n"
-            f"  Camera Keys: {self.camera_keys},\n"
-            f"  Video Frame Keys: {self.camera_keys if self.video else 'N/A'},\n"
-            f"  Transformations: {self.image_transforms},\n"
-            f"  Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
-            f")"
+            f"  Selected episodes: {self.episodes},\n"
+            f"  Number of selected episodes: {self.num_episodes},\n"
+            f"  Number of selected samples: {self.num_samples},\n"
+            f"\n{json.dumps(self.info, indent=4)}\n"
         )

     def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
@@ -563,12 +560,16 @@
         disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
         the hub.

-        Use encode_videos if you want to encode videos during the saving of each episode. Otherwise,
-        you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend
+        Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
+        you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
         time for video encoding.
         """
         episode_length = self.episode_buffer.pop("size")
         episode_index = self.episode_buffer["episode_index"]
+        if episode_index != self.total_episodes:
+            # TODO(aliberts): Add option to use existing episode_index
+            raise NotImplementedError()
         task_index = self.get_task_index(task)
         self.episode_buffer["next.done"][-1] = True
@@ -641,12 +642,30 @@
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()

+    def _update_data_file_names(self) -> None:
+        # TODO(aliberts): remove the need for this hack by removing total_episodes part in data file names.
+        # Must first investigate if this doesn't break hub/datasets features like viewer etc.
+        for ep_idx in range(self.total_episodes):
+            ep_chunk = self.get_episode_chunk(ep_idx)
+            current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
+            current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+            current_file_name = list(self.root.glob(current_file_name))[0]
+            updated_file_name = self.get_data_file_path(ep_idx)
+            current_file_name.rename(updated_file_name)
+
     def consolidate(self, run_compute_stats: bool = True) -> None:
+        self._update_data_file_names()
         if run_compute_stats:
             logging.info("Computing dataset statistics")
             self.hf_dataset = self.load_hf_dataset()
             self.stats = compute_stats(self)
-            write_json()
+            serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
+            serialized_stats = unflatten_dict(serialized_stats)
+            write_json(serialized_stats, self.root / "meta/stats.json")
+        else:
+            logging.warning("Skipping computation of the dataset statistics.")
+
+        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)

         pass # TODO
         # Sanity checks:
         # - [ ] shapes
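
For context on the new stats path in consolidate(): the flat stats dict returned by compute_stats() is converted to plain lists, unflattened, and written to meta/stats.json. Below is a small standalone sketch of that serialization, assuming the flat "group/stat" key layout that the tolist() + unflatten_dict sequence implies; `unflatten` is a hand-rolled stand-in for lerobot's unflatten_dict, json.dump stands in for write_json, and the stats values are made up.

    import json

    import torch

    def unflatten(flat: dict, sep: str = "/") -> dict:
        # Stand-in for lerobot's unflatten_dict: turns {"a/b": x} into {"a": {"b": x}}.
        nested: dict = {}
        for key, value in flat.items():
            *parents, leaf = key.split(sep)
            node = nested
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        return nested

    # Made-up stats using the flat "group/stat" keys this code path implies.
    stats = {
        "observation.state/mean": torch.tensor([0.12, -0.04]),
        "observation.state/std": torch.tensor([0.98, 1.01]),
        "action/mean": torch.tensor([0.0, 0.0]),
    }
    serialized = {key: value.tolist() for key, value in stats.items()}  # tensors -> JSON-friendly lists
    with open("stats.json", "w") as f:  # write_json(serialized, root / "meta/stats.json") in the diff
        json.dump(unflatten(serialized), f, indent=4)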
@@ -666,6 +685,7 @@
         tolerance_s: float = 1e-4,
         image_writer: ImageWriter | None = None,
         use_videos: bool = True,
+        video_backend: str | None = None,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
@@ -674,15 +694,14 @@
         obj._version = CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_writer = image_writer
-        obj.hf_dataset = None

         if not all(cam.fps == fps for cam in robot.cameras.values()):
-            logging.warn(
+            logging.warning(
                 f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
                 "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
             )

-        obj.tasks = {}
+        obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
         obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
         write_json(obj.info, obj.root / "meta/info.json")
@@ -694,14 +713,12 @@
         # In order to be able to push the dataset to the hub, it needs to be consolidation first.
         obj.consolidated = True

-        # obj.episodes = None
-        # obj.image_transforms = None
-        # obj.delta_timestamps = None
-        # obj.episode_data_index = episode_data_index
-        # obj.stats = stats
-        # obj.info = info if info is not None else {}
-        # obj.videos_dir = videos_dir
-        # obj.video_backend = video_backend if video_backend is not None else "pyav"
+        obj.episodes = None
+        obj.hf_dataset = None
+        obj.image_transforms = None
+        obj.delta_timestamps = None
+        obj.episode_data_index = None
+        obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
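
With these changes, create() initializes the same attributes that __init__() sets instead of leaving them commented out, and it gains the same video_backend argument (defaulting to "pyav" in both paths), which is the consistency the commit title refers to. A hedged usage sketch; the repo id is made up and `robot` stands for an already-configured robot object:

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # `robot` is a placeholder for a configured robot from lerobot's control stack.
    dataset = LeRobotDataset.create(
        "user/my_dataset",     # repo_id, made up for the example
        fps=30,
        robot=robot,
        use_videos=True,
        video_backend="pyav",  # new optional argument; same default as __init__()
    )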


@@ -107,9 +107,6 @@ from typing import List
 # from safetensors.torch import load_file, save_file
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.populate_dataset import (
-    create_lerobot_dataset,
-)
 from lerobot.common.robot_devices.control_utils import (
     control_loop,
     has_method,
@@ -210,7 +207,7 @@ def record(
     force_override=False,
     display_cameras=True,
     play_sounds=True,
-):
+) -> LeRobotDataset:
     # TODO(rcadene): Add option to record logs
     listener = None
     events = None
@@ -242,7 +239,7 @@ def record(
         num_processes=num_image_writer_processes,
         num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
     )
-    dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
+    dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)

     if not robot.is_connected:
         robot.connect()
@@ -301,8 +298,8 @@ def record(
             dataset.delete_episode()
             continue

-        # Increment by one dataset["current_episode_index"]
         dataset.add_episode(task)
+        recorded_episodes += 1

         if events["stop_recording"]:
             break
@@ -310,10 +307,17 @@ def record(
     log_say("Stop recording", play_sounds, blocking=True)
     stop_recording(robot, listener, display_cameras)

-    lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
+    logging.info("Waiting for image writer to terminate...")
+    dataset.image_writer.stop()
+
+    dataset.consolidate(run_compute_stats)
+
+    # lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
+    if push_to_hub:
+        dataset.push_to_repo()

     log_say("Exiting", play_sounds)

-    return lerobot_dataset
+    return dataset


 @safe_disconnect
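
In the recording script, the old create_lerobot_dataset() helper is thus replaced by methods on the dataset itself. A condensed sketch of the new end-of-recording sequence, taken from the diff above (push_to_repo() is the method name as it appears in this WIP commit):

    # After the recording loop: episodes were saved with dataset.add_episode(task) and bad
    # takes dropped with dataset.delete_episode().
    logging.info("Waiting for image writer to terminate...")
    dataset.image_writer.stop()              # flush queued frames to disk before consolidating
    dataset.consolidate(run_compute_stats)   # rename data files, compute/write stats, rebuild episode index
    if push_to_hub:
        dataset.push_to_repo()               # upload the consolidated dataset to the hub
    return dataset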