Improve consistency between __init__() and create(), WIP on consolidate
commit e991a31061
parent c4c0a43de7

--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import logging
 import os
 import shutil
@@ -39,6 +40,7 @@ from lerobot.common.datasets.utils import (
     get_hub_safe_version,
     hf_transform_to_torch,
     load_metadata,
+    unflatten_dict,
     write_json,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@@ -163,9 +165,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.image_writer = image_writer
-        self.episode_buffer = {}
-        self.consolidated = True
         self.delta_indices = None
+        self.consolidated = True
+        self.episode_buffer = {}

         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
@@ -501,17 +503,12 @@ class LeRobotDataset(torch.utils.data.Dataset):

     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(\n"
+            f"{self.__class__.__name__}\n"
             f"  Repository ID: '{self.repo_id}',\n"
-            f"  Number of Samples: {self.num_samples},\n"
-            f"  Number of Episodes: {self.num_episodes},\n"
-            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
-            f"  Recorded Frames per Second: {self.fps},\n"
-            f"  Camera Keys: {self.camera_keys},\n"
-            f"  Video Frame Keys: {self.camera_keys if self.video else 'N/A'},\n"
-            f"  Transformations: {self.image_transforms},\n"
-            f"  Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
-            f")"
+            f"  Selected episodes: {self.episodes},\n"
+            f"  Number of selected episodes: {self.num_episodes},\n"
+            f"  Number of selected samples: {self.num_samples},\n"
+            f"\n{json.dumps(self.info, indent=4)}\n"
         )

     def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
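
For illustration, the reworked __repr__ would print roughly the following (repository id, counts, and info fields are hypothetical; the trailing JSON is whatever self.info contains):

    LeRobotDataset
      Repository ID: 'my_user/my_dataset',
      Selected episodes: None,
      Number of selected episodes: 50,
      Number of selected samples: 25000,

    {
        "codebase_version": "v2.0",
        ...
    }
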
@@ -563,12 +560,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
         the hub.

-        Use encode_videos if you want to encode videos during the saving of each episode. Otherwise,
-        you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend
+        Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
+        you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
         time for video encoding.
         """
         episode_length = self.episode_buffer.pop("size")
         episode_index = self.episode_buffer["episode_index"]
+        if episode_index != self.total_episodes:
+            # TODO(aliberts): Add option to use existing episode_index
+            raise NotImplementedError()
+
         task_index = self.get_task_index(task)
         self.episode_buffer["next.done"][-1] = True

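
A minimal sketch of the recording flow this docstring describes, with video encoding deferred to consolidate(). The method names come from this diff; the repo id, task string, and the loop are illustrative, a connected robot and num_episodes are assumed to exist, and encode_videos is assumed to default to False:

    dataset = LeRobotDataset.create("my_user/my_dataset", 30, robot)  # hypothetical repo_id, fps=30

    for _ in range(num_episodes):
        # ...fill dataset.episode_buffer with frames during teleoperation...
        dataset.add_episode(task="pick up the cube")  # no per-episode video encoding

    # Pay the video encoding / statistics cost once, at the end:
    dataset.consolidate(run_compute_stats=True)
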
@@ -641,12 +642,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()

+    def _update_data_file_names(self) -> None:
+        # TODO(aliberts): remove the need for this hack by removing total_episodes part in data file names.
+        # Must first investigate if this doesn't break hub/datasets features like viewer etc.
+        for ep_idx in range(self.total_episodes):
+            ep_chunk = self.get_episode_chunk(ep_idx)
+            current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
+            current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
+            current_file_name = list(self.root.glob(current_file_name))[0]
+            updated_file_name = self.get_data_file_path(ep_idx)
+            current_file_name.rename(updated_file_name)
+
     def consolidate(self, run_compute_stats: bool = True) -> None:
+        self._update_data_file_names()
         if run_compute_stats:
             logging.info("Computing dataset statistics")
             self.hf_dataset = self.load_hf_dataset()
             self.stats = compute_stats(self)
-            write_json()
+            serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
+            serialized_stats = unflatten_dict(serialized_stats)
+            write_json(serialized_stats, self.root / "meta/stats.json")
+        else:
+            logging.warning("Skipping computation of the dataset statistics.")
+
+        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
         pass  # TODO
         # Sanity checks:
         # - [ ] shapes
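
The consolidate() change above writes stats as nested JSON. A hedged sketch of that serialization step, assuming (as the use of unflatten_dict suggests) that compute_stats() returns a flat dict with "/"-separated keys and tensor values; the keys and numbers here are illustrative:

    from pathlib import Path
    import torch
    from lerobot.common.datasets.utils import unflatten_dict, write_json

    root = Path("data/my_user/my_dataset")  # hypothetical dataset root
    stats = {
        "observation.state/mean": torch.tensor([0.1, 0.2]),
        "observation.state/std": torch.tensor([1.0, 1.0]),
    }
    serialized = {key: value.tolist() for key, value in stats.items()}  # tensors -> lists
    serialized = unflatten_dict(serialized)  # "a/b" keys -> {"a": {"b": ...}}
    # serialized == {"observation.state": {"mean": [0.1, 0.2], "std": [1.0, 1.0]}}
    write_json(serialized, root / "meta/stats.json")
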
@@ -666,6 +685,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         tolerance_s: float = 1e-4,
         image_writer: ImageWriter | None = None,
         use_videos: bool = True,
+        video_backend: str | None = None,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
@@ -674,15 +694,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj._version = CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_writer = image_writer
-        obj.hf_dataset = None

         if not all(cam.fps == fps for cam in robot.cameras.values()):
-            logging.warn(
+            logging.warning(
                 f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
                 "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
             )

-        obj.tasks = {}
+        obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
         obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
         write_json(obj.info, obj.root / "meta/info.json")

@@ -694,14 +713,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # In order to be able to push the dataset to the hub, it needs to be consolidation first.
         obj.consolidated = True

-        # obj.episodes = None
-        # obj.image_transforms = None
-        # obj.delta_timestamps = None
-        # obj.episode_data_index = episode_data_index
-        # obj.stats = stats
-        # obj.info = info if info is not None else {}
-        # obj.videos_dir = videos_dir
-        # obj.video_backend = video_backend if video_backend is not None else "pyav"
+        obj.episodes = None
+        obj.hf_dataset = None
+        obj.image_transforms = None
+        obj.delta_timestamps = None
+        obj.episode_data_index = None
+        obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj

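
With create() now initializing the same attributes as __init__(), a hedged usage sketch; the keyword names are the ones appearing in this diff, while the repo id is hypothetical and a configured robot and an ImageWriter are assumed to exist:

    dataset = LeRobotDataset.create(
        "my_user/my_dataset",       # hypothetical repo_id
        30,                         # fps
        robot,                      # a configured Robot with cameras
        image_writer=image_writer,  # optional async ImageWriter
        use_videos=True,
        video_backend=None,         # falls back to "pyav"
    )
    assert dataset.consolidated     # a freshly created (empty) dataset starts consolidated
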
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -107,9 +107,6 @@ from typing import List
 # from safetensors.torch import load_file, save_file
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.populate_dataset import (
-    create_lerobot_dataset,
-)
 from lerobot.common.robot_devices.control_utils import (
     control_loop,
     has_method,
@@ -210,7 +207,7 @@ def record(
     force_override=False,
     display_cameras=True,
     play_sounds=True,
-):
+) -> LeRobotDataset:
     # TODO(rcadene): Add option to record logs
     listener = None
     events = None
@@ -242,7 +239,7 @@ def record(
             num_processes=num_image_writer_processes,
             num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
         )
-    dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
+    dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)

     if not robot.is_connected:
         robot.connect()
@@ -301,8 +298,8 @@ def record(
             dataset.delete_episode()
             continue

-        # Increment by one dataset["current_episode_index"]
         dataset.add_episode(task)
+        recorded_episodes += 1

         if events["stop_recording"]:
             break
@@ -310,10 +307,17 @@ def record(
     log_say("Stop recording", play_sounds, blocking=True)
     stop_recording(robot, listener, display_cameras)

-    lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
+    logging.info("Waiting for image writer to terminate...")
+    dataset.image_writer.stop()
+
+    dataset.consolidate(run_compute_stats)
+
+    # lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
+    if push_to_hub:
+        dataset.push_to_repo()

     log_say("Exiting", play_sounds)
-    return lerobot_dataset
+    return dataset


 @safe_disconnect
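
End to end, record() now returns the LeRobotDataset itself rather than the product of create_lerobot_dataset(). A hedged caller sketch; the keyword names are those used inside this diff, while the argument order, values, and the availability of record and robot in scope are assumptions:

    dataset = record(
        robot,
        root="data",                    # now forwarded to LeRobotDataset.create()
        repo_id="my_user/my_dataset",   # hypothetical
        fps=30,
        run_compute_stats=True,         # stats are computed inside dataset.consolidate()
        push_to_hub=False,              # True would call dataset.push_to_repo()
    )
    print(dataset)  # new __repr__: selected episodes, counts, and the info JSON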