Improve consistency between __init__() and create(), WIP on consolidate
This commit is contained in:
parent
c4c0a43de7
commit
e991a31061
|
@ -13,6 +13,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
@ -39,6 +40,7 @@ from lerobot.common.datasets.utils import (
|
|||
get_hub_safe_version,
|
||||
hf_transform_to_torch,
|
||||
load_metadata,
|
||||
unflatten_dict,
|
||||
write_json,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
|
||||
|
@ -163,9 +165,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.download_videos = download_videos
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.image_writer = image_writer
|
||||
self.episode_buffer = {}
|
||||
self.consolidated = True
|
||||
self.delta_indices = None
|
||||
self.consolidated = True
|
||||
self.episode_buffer = {}
|
||||
|
||||
# Load metadata
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
@ -501,17 +503,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f"{self.__class__.__name__}\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.camera_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
|
||||
f")"
|
||||
f" Selected episodes: {self.episodes},\n"
|
||||
f" Number of selected episodes: {self.num_episodes},\n"
|
||||
f" Number of selected samples: {self.num_samples},\n"
|
||||
f"\n{json.dumps(self.info, indent=4)}\n"
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
|
@ -563,12 +560,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
|
||||
Use encode_videos if you want to encode videos during the saving of each episode. Otherwise,
|
||||
you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
|
||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
"""
|
||||
episode_length = self.episode_buffer.pop("size")
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if episode_index != self.total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError()
|
||||
|
||||
task_index = self.get_task_index(task)
|
||||
self.episode_buffer["next.done"][-1] = True
|
||||
|
||||
|
@ -641,12 +642,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def _update_data_file_names(self) -> None:
|
||||
# TODO(aliberts): remove the need for this hack by removing total_episodes part in data file names.
|
||||
# Must first investigate if this doesn't break hub/datasets features like viewer etc.
|
||||
for ep_idx in range(self.total_episodes):
|
||||
ep_chunk = self.get_episode_chunk(ep_idx)
|
||||
current_file_name = self.data_path.replace("{total_episodes:05d}", "*")
|
||||
current_file_name = current_file_name.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
current_file_name = list(self.root.glob(current_file_name))[0]
|
||||
updated_file_name = self.get_data_file_path(ep_idx)
|
||||
current_file_name.rename(updated_file_name)
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True) -> None:
|
||||
self._update_data_file_names()
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.stats = compute_stats(self)
|
||||
write_json()
|
||||
serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
write_json(serialized_stats, self.root / "meta/stats.json")
|
||||
else:
|
||||
logging.warning("Skipping computation of the dataset statistics.")
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||
pass # TODO
|
||||
# Sanity checks:
|
||||
# - [ ] shapes
|
||||
|
@ -666,6 +685,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
tolerance_s: float = 1e-4,
|
||||
image_writer: ImageWriter | None = None,
|
||||
use_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
obj = cls.__new__(cls)
|
||||
|
@ -674,15 +694,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj._version = CODEBASE_VERSION
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = image_writer
|
||||
obj.hf_dataset = None
|
||||
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warn(
|
||||
logging.warning(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
|
||||
obj.tasks = {}
|
||||
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
|
||||
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
|
||||
write_json(obj.info, obj.root / "meta/info.json")
|
||||
|
||||
|
@ -694,14 +713,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# In order to be able to push the dataset to the hub, it needs to be consolidation first.
|
||||
obj.consolidated = True
|
||||
|
||||
# obj.episodes = None
|
||||
# obj.image_transforms = None
|
||||
# obj.delta_timestamps = None
|
||||
# obj.episode_data_index = episode_data_index
|
||||
# obj.stats = stats
|
||||
# obj.info = info if info is not None else {}
|
||||
# obj.videos_dir = videos_dir
|
||||
# obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = None
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
||||
|
||||
|
|
|
@ -107,9 +107,6 @@ from typing import List
|
|||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import (
|
||||
create_lerobot_dataset,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
has_method,
|
||||
|
@ -210,7 +207,7 @@ def record(
|
|||
force_override=False,
|
||||
display_cameras=True,
|
||||
play_sounds=True,
|
||||
):
|
||||
) -> LeRobotDataset:
|
||||
# TODO(rcadene): Add option to record logs
|
||||
listener = None
|
||||
events = None
|
||||
|
@ -242,7 +239,7 @@ def record(
|
|||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
|
||||
dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
@ -301,8 +298,8 @@ def record(
|
|||
dataset.delete_episode()
|
||||
continue
|
||||
|
||||
# Increment by one dataset["current_episode_index"]
|
||||
dataset.add_episode(task)
|
||||
recorded_episodes += 1
|
||||
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
|
@ -310,10 +307,17 @@ def record(
|
|||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
dataset.image_writer.stop()
|
||||
|
||||
dataset.consolidate(run_compute_stats)
|
||||
|
||||
# lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
if push_to_hub:
|
||||
dataset.push_to_repo()
|
||||
|
||||
log_say("Exiting", play_sounds)
|
||||
return lerobot_dataset
|
||||
return dataset
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
|
|
Loading…
Reference in New Issue