Add add_frame, empty dataset creation

Simon Alibert 2024-10-21 00:16:52 +02:00
parent 3b925c3dce
commit c1232a01e2
6 changed files with 114 additions and 33 deletions

lerobot/common/datasets/lerobot_dataset.py

@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 import os
 from pathlib import Path
@@ -26,15 +25,17 @@ from datasets import load_dataset
 from huggingface_hub import snapshot_download

 from lerobot.common.datasets.compute_stats import aggregate_stats
+from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
     check_delta_timestamps,
     check_timestamps_sync,
-    create_dataset_info,
+    create_empty_dataset_info,
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
     hf_transform_to_torch,
     load_metadata,
+    write_json,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
 from lerobot.common.robot_devices.robots.utils import Robot
@@ -55,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         tolerance_s: float = 1e-4,
         download_videos: bool = True,
         video_backend: str | None = None,
+        image_writer: ImageWriter | None = None,
     ):
         """LeRobotDataset encapsulates 3 main things:
         - metadata:
@@ -156,6 +158,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tolerance_s = tolerance_s
         self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
+        self.image_writer = image_writer
+        self.episode_buffer = {}
         self.delta_indices = None

         # Load metadata
@@ -296,9 +300,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @property
     def num_samples(self) -> int:
-        """Number of samples/frames."""
+        """Number of samples/frames in selected episodes."""
         return len(self.hf_dataset)

+    @property
+    def total_frames(self) -> int:
+        """Total number of frames saved in this dataset."""
+        return self.info["total_frames"]
+
     @property
     def num_episodes(self) -> int:
         """Number of episodes selected."""
@@ -423,10 +432,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return item

-    def write_info(self) -> None:
-        with open(self.root / "meta/info.json", "w") as f:
-            json.dump(self.info, f, indent=4, ensure_ascii=False)
-
     def __repr__(self):
         return (
             f"{self.__class__.__name__}(\n"
@@ -442,6 +447,49 @@ class LeRobotDataset(torch.utils.data.Dataset):
             f")"
         )

+    def _create_episode_buffer(self) -> dict:
+        # TODO(aliberts): Handle resume
+        return {
+            "chunk": self.total_chunks,
+            "episode_index": self.total_episodes,
+            "size": 0,
+            "frame_index": [],
+            "timestamp": [],
+            "next.done": [],
+            **{key: [] for key in self.keys},
+        }
+
+    def add_frame(self, frame: dict) -> None:
+        frame_index = self.episode_buffer["size"]
+        self.episode_buffer["frame_index"].append(frame_index)
+        self.episode_buffer["timestamp"].append(frame_index / self.fps)
+        self.episode_buffer["next.done"].append(False)
+
+        # Save all observed modalities except images
+        for key in self.keys:
+            self.episode_buffer[key].append(frame[key])
+
+        self.episode_buffer["size"] += 1
+
+        if self.image_writer is None:
+            return
+
+        # Save images
+        for cam_key in self.camera_keys:
+            img_path = self.image_writer.get_image_file_path(
+                episode_index=self.episode_buffer["episode_index"],
+                image_key=cam_key,
+                frame_index=frame_index,
+                return_str=False,
+            )
+            if frame_index == 0:
+                img_path.parent.mkdir(parents=True, exist_ok=True)
+            self.image_writer.async_save_image(
+                image=frame[cam_key],
+                file_path=img_path,
+            )
+
     @classmethod
     def create(
         cls,
@@ -450,24 +498,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
         robot: Robot,
         root: Path | None = None,
         tolerance_s: float = 1e-4,
+        image_writer: ImageWriter | None = None,
+        use_videos: bool = True,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
         obj._version = CODEBASE_VERSION
+        obj.tolerance_s = tolerance_s
+        obj.image_writer = image_writer

-        obj.root.mkdir(exist_ok=True, parents=True)
-        obj.info = create_dataset_info(obj._version, fps, robot)
-        obj.write_info()
-        obj.fps = fps
-
-        if not all(cam.fps == fps for cam in robot.cameras):
+        if not all(cam.fps == fps for cam in robot.cameras.values()):
             logging.warn(
                 f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
                 "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
             )

+        obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
+        write_json(obj.info, obj.root / "meta/info.json")
+
+        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
+        obj.episode_buffer = obj._create_episode_buffer()
+
        # obj.episodes = None
        # obj.image_transforms = None
        # obj.delta_timestamps = None
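
Note (not part of the commit): a minimal sketch of how the pieces added above fit together when recording. The repo id, writer settings, and the `record_one_frame` helper are invented for illustration; `teleop_step(record_data=True)` is assumed from the existing robot API:

from pathlib import Path

from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.robot_devices.robots.utils import Robot


def record_one_frame(robot: Robot) -> LeRobotDataset:
    # Writer settings are illustrative; record() in the control script passes its own.
    image_writer = ImageWriter(
        write_dir=Path("data/example"),
        num_image_writer_processes=0,
        num_image_writer_threads=4,
    )
    dataset = LeRobotDataset.create("user/example", fps=30, robot=robot, image_writer=image_writer)

    # add_frame buffers the non-image modalities in memory and queues camera
    # frames for asynchronous writing through the ImageWriter.
    observation, action = robot.teleop_step(record_data=True)
    dataset.add_frame({**observation, **action})
    return dataset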

lerobot/common/datasets/utils.py

@@ -75,6 +75,12 @@ def unflatten_dict(d, sep="/"):
     return outdict


+def write_json(data: dict, fpath: Path) -> None:
+    fpath.parent.mkdir(exist_ok=True, parents=True)
+    with open(fpath, "w") as f:
+        json.dump(data, f, indent=4, ensure_ascii=False)
+
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
@@ -146,7 +152,16 @@ def load_metadata(local_dir: Path) -> tuple[dict | list]:
     return info, episode_dicts, stats, tasks


-def create_dataset_info(codebase_version: str, fps: int, robot: Robot) -> dict:
+def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
+    shapes = {key: len(names) for key, names in robot.names.items()}
+    camera_shapes = {}
+    for key, cam in robot.cameras.items():
+        video_key = f"observation.images.{key}"
+        camera_shapes[video_key] = {
+            "width": cam.width,
+            "height": cam.height,
+            "channels": cam.channels,
+        }
     return {
         "codebase_version": codebase_version,
         "data_path": DEFAULT_PARQUET_PATH,
@@ -159,12 +174,12 @@ def create_dataset_info(codebase_version: str, fps: int, robot: Robot) -> dict:
         "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": fps,
         "splits": {},
-        # "keys": keys,
-        # "video_keys": video_keys,
-        # "image_keys": image_keys,
-        # "shapes": {**sequence_shapes, **video_shapes, **image_shapes},
-        # "names": names,
-        # "videos": {"videos_path": DEFAULT_VIDEO_PATH} if video_keys else None,
+        "keys": list(robot.names),
+        "video_keys": list(camera_shapes) if use_videos else [],
+        "image_keys": [] if use_videos else list(camera_shapes),
+        "shapes": {**shapes, **camera_shapes},
+        "names": robot.names,
+        "videos": {"videos_path": DEFAULT_VIDEO_PATH} if use_videos else None,
     }
@@ -270,6 +285,7 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
     return delta_indices


+# TODO(aliberts): remove
 def load_previous_and_future_frames(
     item: dict[str, torch.Tensor],
     hf_dataset: datasets.Dataset,
@@ -363,6 +379,7 @@ def load_previous_and_future_frames(
     return item


+# TODO(aliberts): remove
 def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
     """
     Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
@@ -417,6 +434,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
     return episode_data_index


+# TODO(aliberts): remove
 def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
     """Reset the `episode_index` of the provided HuggingFace Dataset.
@@ -454,7 +472,7 @@ def cycle(iterable):
     iterator = iter(iterable)


-def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
+def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
     """Create a branch on a existing Hugging Face repo. Delete the branch if it already
     exists before creating it.
     """

lerobot/common/robot_devices/cameras/opencv.py

@@ -192,6 +192,7 @@ class OpenCVCameraConfig:
     width: int | None = None
     height: int | None = None
     color_mode: str = "rgb"
+    channels: int | None = None
     rotation: int | None = None
     mock: bool = False
@@ -201,6 +202,8 @@ class OpenCVCameraConfig:
                 f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
             )

+        self.channels = 3
+
         if self.rotation not in [-90, None, 90, 180]:
             raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
@@ -268,6 +271,7 @@ class OpenCVCamera:
         self.fps = config.fps
         self.width = config.width
         self.height = config.height
+        self.channels = config.channels
         self.color_mode = config.color_mode
         self.mock = config.mock
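
A small sketch of the config behavior added here (the import path and constructor call are assumptions, not shown in the diff): `__post_init__` resolves `channels` to 3, consistent with the 3-channel 'rgb'/'bgr' color modes, so `create_empty_dataset_info` can read it when building camera shapes:

from lerobot.common.robot_devices.cameras.opencv import OpenCVCameraConfig

cfg = OpenCVCameraConfig(fps=30, width=640, height=480)
assert cfg.channels == 3  # resolved in __post_init__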

lerobot/common/robot_devices/control_utils.py

@@ -15,7 +15,8 @@ import torch
 import tqdm
 from termcolor import colored

-from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
+from lerobot.common.datasets.image_writer import safe_stop_image_writer
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
@@ -227,7 +228,7 @@ def control_loop(
     control_time_s=None,
     teleoperate=False,
     display_cameras=False,
-    dataset=None,
+    dataset: LeRobotDataset | None = None,
     events=None,
     policy=None,
     device=None,
@@ -268,7 +269,8 @@ def control_loop(
             action = {"action": action}

         if dataset is not None:
-            add_frame(dataset, observation, action)
+            frame = {**observation, **action}
+            dataset.add_frame(frame)

         if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
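
Schematically, the loop now hands `add_frame` one flat dict per step, whose keys are expected to match the dataset's `keys` and `camera_keys` (the key names and tensor shapes below are placeholders):

import torch

observation = {
    "observation.state": torch.zeros(6),
    "observation.images.top": torch.zeros(3, 480, 640),
}
action = {"action": torch.zeros(6)}

frame = {**observation, **action}  # same merge as in control_loop above
# dataset.add_frame(frame) then buffers state/action and queues the image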

lerobot/common/robot_devices/robots/manipulator.py

@@ -349,6 +349,13 @@ class ManipulatorRobot:
         self.is_connected = False
         self.logs = {}

+        action_names = [f"{arm}_{motor}" for arm, bus in self.leader_arms.items() for motor in bus.motors]
+        state_names = [f"{arm}_{motor}" for arm, bus in self.follower_arms.items() for motor in bus.motors]
+        self.names = {
+            "action": action_names,
+            "observation.state": state_names,
+        }
+
     @property
     def has_camera(self):
         return len(self.cameras) > 0
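
As an example of the mapping this adds, a robot with one leader arm and one follower arm, both named "main" and each driving motors "shoulder_pan" and "gripper" (motor names invented), would get:

names = {
    "action": ["main_shoulder_pan", "main_gripper"],
    "observation.state": ["main_shoulder_pan", "main_gripper"],
}

This is the attribute `create_empty_dataset_info` reads (as `robot.names`) to populate the new "keys", "shapes" and "names" metadata fields.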

lerobot/scripts/control_robot.py

@@ -105,11 +105,11 @@ from pathlib import Path
 from typing import List

 # from safetensors.torch import load_file, save_file
+from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.populate_dataset import (
     create_lerobot_dataset,
     delete_current_episode,
-    init_dataset,
     save_current_episode,
 )
 from lerobot.common.robot_devices.control_utils import (
@@ -233,16 +233,12 @@ def record(
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    dataset = init_dataset(
-        repo_id,
-        root,
-        force_override,
-        fps,
-        video,
-        write_images=robot.has_camera,
+    image_writer = ImageWriter(
+        write_dir=root,
         num_image_writer_processes=num_image_writer_processes,
         num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
     )
+    dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)

     if not robot.is_connected:
         robot.connect()
@@ -260,8 +256,9 @@ def record(
     if has_method(robot, "teleop_safety_stop"):
         robot.teleop_safety_stop()

+    recorded_episodes = 0
     while True:
-        if dataset["num_episodes"] >= num_episodes:
+        if recorded_episodes >= num_episodes:
             break

-        episode_index = dataset["num_episodes"]