Improve consistency between __init__() and create(), WIP on consolidate

Simon Alibert 2024-10-22 00:19:25 +02:00
parent c4c0a43de7
commit e991a31061
2 changed files with 55 additions and 34 deletions

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import shutil
@@ -39,6 +40,7 @@ from lerobot.common.datasets.utils import (
get_hub_safe_version,
hf_transform_to_torch,
load_metadata,
unflatten_dict,
write_json,
)
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@@ -163,9 +165,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.download_videos = download_videos
self.video_backend = video_backend if video_backend is not None else "pyav"
self.image_writer = image_writer
self.episode_buffer = {}
self.consolidated = True
self.delta_indices = None
self.consolidated = True
self.episode_buffer = {}
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
@@ -501,17 +503,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f"{self.__class__.__name__}\n"
f" Repository ID: '{self.repo_id}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.camera_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
f")"
f" Selected episodes: {self.episodes},\n"
f" Number of selected episodes: {self.num_episodes},\n"
f" Number of selected samples: {self.num_samples},\n"
f"\n{json.dumps(self.info, indent=4)}\n"
)
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
@@ -563,12 +560,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
Use encode_videos if you want to encode videos during the saving of each episode. Otherwise,
you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend
Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
time on video encoding.
"""
episode_length = self.episode_buffer.pop("size")
episode_index = self.episode_buffer["episode_index"]
if episode_index != self.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError()
task_index = self.get_task_index(task)
self.episode_buffer["next.done"][-1] = True
@@ -641,12 +642,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
def _update_data_file_names(self) -> None:
# TODO(aliberts): remove the need for this hack by removing the total_episodes part from data file names.
# Must first investigate whether this breaks hub/datasets features like the viewer.
for ep_idx in range(self.total_episodes):
ep_chunk = self.get_episode_chunk(ep_idx)
current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
current_file_name = list(self.root.glob(current_file_name))[0]
updated_file_name = self.get_data_file_path(ep_idx)
current_file_name.rename(updated_file_name)
def consolidate(self, run_compute_stats: bool = True) -> None:
self._update_data_file_names()
if run_compute_stats:
logging.info("Computing dataset statistics")
self.hf_dataset = self.load_hf_dataset()
self.stats = compute_stats(self)
write_json()
serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
write_json(serialized_stats, self.root / "meta/stats.json")
else:
logging.warning("Skipping computation of the dataset statistics.")
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
pass # TODO
# Sanity checks:
# - [ ] shapes
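The stats serialization above writes meta/stats.json by first converting each tensor to a plain list (value.tolist()) and then unflattening the keys into nested dictionaries. A minimal sketch of that unflattening step, assuming "/"-separated keys and hypothetical stat names (an illustration, not lerobot's unflatten_dict implementation):

def unflatten(flat: dict, sep: str = "/") -> dict:
    # Rebuild a nested dict from flat keys like "observation.state/mean".
    nested: dict = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested

serialized = {"observation.state/mean": [0.0, 0.0], "observation.state/std": [1.0, 1.0]}
print(unflatten(serialized))
# -> {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}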
@@ -666,6 +685,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
tolerance_s: float = 1e-4,
image_writer: ImageWriter | None = None,
use_videos: bool = True,
video_backend: str | None = None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls)
@@ -674,15 +694,14 @@
obj._version = CODEBASE_VERSION
obj.tolerance_s = tolerance_s
obj.image_writer = image_writer
obj.hf_dataset = None
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warn(
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
obj.tasks = {}
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
write_json(obj.info, obj.root / "meta/info.json")
@@ -694,14 +713,12 @@
# In order to be able to push the dataset to the hub, it needs to be consolidated first.
obj.consolidated = True
# obj.episodes = None
# obj.image_transforms = None
# obj.delta_timestamps = None
# obj.episode_data_index = episode_data_index
# obj.stats = stats
# obj.info = info if info is not None else {}
# obj.videos_dir = videos_dir
# obj.video_backend = video_backend if video_backend is not None else "pyav"
obj.episodes = None
obj.hf_dataset = None
obj.image_transforms = None
obj.delta_timestamps = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj
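Taken together with the docstring change above (saving an episode to disk sets self.consolidated to False, and consolidation must happen before uploading), the intended flow after this commit looks roughly like the sketch below. Names follow this WIP diff, including push_to_repo; the repo_id, task string, and the frame-recording loop are placeholders:

# Rough usage sketch under the assumptions stated above; robot and image_writer
# are created elsewhere (e.g. by the record script in the next file).
dataset = LeRobotDataset.create("lerobot/my_task", 30, robot, image_writer=image_writer)
# ... record frames for one episode (elided) ...
dataset.add_episode("pick up the cube")      # writes episode data to disk, sets dataset.consolidated = False
dataset.consolidate(run_compute_stats=True)  # renames data files and recomputes dataset statistics
dataset.push_to_repo()                       # pushing is only meaningful once the dataset is consolidated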

View File

@@ -107,9 +107,6 @@ from typing import List
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import (
create_lerobot_dataset,
)
from lerobot.common.robot_devices.control_utils import (
control_loop,
has_method,
@@ -210,7 +207,7 @@ def record(
force_override=False,
display_cameras=True,
play_sounds=True,
):
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
listener = None
events = None
@@ -242,7 +239,7 @@
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)
if not robot.is_connected:
robot.connect()
@@ -301,8 +298,8 @@ def record(
dataset.delete_episode()
continue
# Increment dataset["current_episode_index"] by one
dataset.add_episode(task)
recorded_episodes += 1
if events["stop_recording"]:
break
@@ -310,10 +307,17 @@
log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras)
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
logging.info("Waiting for image writer to terminate...")
dataset.image_writer.stop()
dataset.consolidate(run_compute_stats)
# lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
if push_to_hub:
dataset.push_to_repo()
log_say("Exiting", play_sounds)
return lerobot_dataset
return dataset
@safe_disconnect
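One detail in the record() changes: the image writer is stopped before consolidation, presumably so every queued frame is on disk before consolidate() reloads the dataset and computes statistics over it, and pushing happens only after that. Condensed from the diff above (push_to_repo is the WIP name used in this commit):

dataset.image_writer.stop()             # wait for queued frames to be written to disk
dataset.consolidate(run_compute_stats)  # then fix up data file names and (optionally) compute stats
if push_to_hub:
    dataset.push_to_repo()              # push only once the dataset is consolidated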