Add add_frame, empty dataset creation

parent 3b925c3dce
commit c1232a01e2

--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 import os
 from pathlib import Path
@@ -26,15 +25,17 @@ from datasets import load_dataset
 from huggingface_hub import snapshot_download
 
 from lerobot.common.datasets.compute_stats import aggregate_stats
+from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
     check_delta_timestamps,
     check_timestamps_sync,
-    create_dataset_info,
+    create_empty_dataset_info,
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
     hf_transform_to_torch,
     load_metadata,
+    write_json,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
 from lerobot.common.robot_devices.robots.utils import Robot
@@ -55,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         tolerance_s: float = 1e-4,
         download_videos: bool = True,
         video_backend: str | None = None,
+        image_writer: ImageWriter | None = None,
     ):
         """LeRobotDataset encapsulates 3 main things:
             - metadata:
@@ -156,6 +158,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tolerance_s = tolerance_s
         self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
+        self.image_writer = image_writer
+        self.episode_buffer = {}
         self.delta_indices = None
 
         # Load metadata
@@ -296,9 +300,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def num_samples(self) -> int:
-        """Number of samples/frames."""
+        """Number of samples/frames in selected episodes."""
         return len(self.hf_dataset)
 
+    @property
+    def total_frames(self) -> int:
+        """Total number of frames saved in this dataset."""
+        return self.info["total_frames"]
+
     @property
     def num_episodes(self) -> int:
         """Number of episodes selected."""
@@ -423,10 +432,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return item
 
-    def write_info(self) -> None:
-        with open(self.root / "meta/info.json", "w") as f:
-            json.dump(self.info, f, indent=4, ensure_ascii=False)
-
     def __repr__(self):
         return (
             f"{self.__class__.__name__}(\n"
@@ -442,6 +447,49 @@ class LeRobotDataset(torch.utils.data.Dataset):
             f")"
         )
 
+    def _create_episode_buffer(self) -> dict:
+        # TODO(aliberts): Handle resume
+        return {
+            "chunk": self.total_chunks,
+            "episode_index": self.total_episodes,
+            "size": 0,
+            "frame_index": [],
+            "timestamp": [],
+            "next.done": [],
+            **{key: [] for key in self.keys},
+        }
+
+    def add_frame(self, frame: dict) -> None:
+        frame_index = self.episode_buffer["size"]
+        self.episode_buffer["frame_index"].append(frame_index)
+        self.episode_buffer["timestamp"].append(frame_index / self.fps)
+        self.episode_buffer["next.done"].append(False)
+
+        # Save all observed modalities except images
+        for key in self.keys:
+            self.episode_buffer[key].append(frame[key])
+
+        self.episode_buffer["size"] += 1
+
+        if self.image_writer is None:
+            return
+
+        # Save images
+        for cam_key in self.camera_keys:
+            img_path = self.image_writer.get_image_file_path(
+                episode_index=self.episode_buffer["episode_index"],
+                image_key=cam_key,
+                frame_index=frame_index,
+                return_str=False,
+            )
+            if frame_index == 0:
+                img_path.parent.mkdir(parents=True, exist_ok=True)
+
+            self.image_writer.async_save_image(
+                image=frame[cam_key],
+                file_path=img_path,
+            )
+
     @classmethod
     def create(
         cls,
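
Taken together, the two methods above implement in-memory episode accumulation: `_create_episode_buffer` allocates one list per key, and `add_frame` appends a single timestep, deriving `timestamp` from the dataset fps while camera frames are handed to the ImageWriter for asynchronous saving. A hardware-free sketch of the buffering logic, with illustrative `keys` and `fps` values:

    # Minimal stand-in for the episode buffer logic above (no robot or ImageWriter needed).
    fps = 30
    keys = ["observation.state", "action"]  # non-image modalities, mirrors self.keys

    episode_buffer = {
        "chunk": 0,
        "episode_index": 0,
        "size": 0,
        "frame_index": [],
        "timestamp": [],
        "next.done": [],
        **{key: [] for key in keys},
    }

    def add_frame(frame: dict) -> None:
        frame_index = episode_buffer["size"]
        episode_buffer["frame_index"].append(frame_index)
        episode_buffer["timestamp"].append(frame_index / fps)  # timestamps follow from fps
        episode_buffer["next.done"].append(False)
        for key in keys:
            episode_buffer[key].append(frame[key])
        episode_buffer["size"] += 1

    add_frame({"observation.state": [0.0] * 6, "action": [0.0] * 6})
    assert episode_buffer["timestamp"] == [0.0] and episode_buffer["size"] == 1
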
@@ -450,24 +498,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
         robot: Robot,
         root: Path | None = None,
         tolerance_s: float = 1e-4,
+        image_writer: ImageWriter | None = None,
+        use_videos: bool = True,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
         obj._version = CODEBASE_VERSION
+        obj.tolerance_s = tolerance_s
+        obj.image_writer = image_writer
 
-        obj.root.mkdir(exist_ok=True, parents=True)
-        obj.info = create_dataset_info(obj._version, fps, robot)
-        obj.write_info()
-        obj.fps = fps
-
-        if not all(cam.fps == fps for cam in robot.cameras):
+        if not all(cam.fps == fps for cam in robot.cameras.values()):
             logging.warn(
                 f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
                 "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
             )
 
+        obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
+        write_json(obj.info, obj.root / "meta/info.json")
+
+        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
+        obj.episode_buffer = obj._create_episode_buffer()
+
         # obj.episodes = None
         # obj.image_transforms = None
         # obj.delta_timestamps = None
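
The intended call pattern for the new factory, as a sketch (the repo id, path, and fps are illustrative, `robot` is assumed to be an already-configured Robot instance, and the remaining ImageWriter arguments are left at their defaults):

    from pathlib import Path

    from lerobot.common.datasets.image_writer import ImageWriter
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    image_writer = ImageWriter(write_dir=Path("data/my_dataset"))  # illustrative path
    dataset = LeRobotDataset.create(
        "user/my_dataset",  # repo_id, illustrative
        fps=30,
        robot=robot,  # assumed: a configured Robot with cameras and names
        image_writer=image_writer,
        use_videos=True,  # camera streams become video_keys rather than image_keys
    )
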
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -75,6 +75,12 @@ def unflatten_dict(d, sep="/"):
     return outdict
 
 
+def write_json(data: dict, fpath: Path) -> None:
+    fpath.parent.mkdir(exist_ok=True, parents=True)
+    with open(fpath, "w") as f:
+        json.dump(data, f, indent=4, ensure_ascii=False)
+
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
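
`write_json` takes over from the `LeRobotDataset.write_info` method removed above, with one behavioral addition: the parent directory is created if missing, which is why `create` no longer calls `obj.root.mkdir` itself. Illustrative usage:

    from pathlib import Path

    info = {"codebase_version": "v2.0", "fps": 30}  # contents illustrative
    write_json(info, Path("data/my_dataset/meta/info.json"))  # creates meta/ if absent
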
@@ -146,7 +152,16 @@ def load_metadata(local_dir: Path) -> tuple[dict | list]:
     return info, episode_dicts, stats, tasks
 
 
-def create_dataset_info(codebase_version: str, fps: int, robot: Robot) -> dict:
+def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
+    shapes = {key: len(names) for key, names in robot.names.items()}
+    camera_shapes = {}
+    for key, cam in robot.cameras.items():
+        video_key = f"observation.images.{key}"
+        camera_shapes[video_key] = {
+            "width": cam.width,
+            "height": cam.height,
+            "channels": cam.channels,
+        }
     return {
         "codebase_version": codebase_version,
         "data_path": DEFAULT_PARQUET_PATH,
@@ -159,12 +174,12 @@
         "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": fps,
         "splits": {},
-        # "keys": keys,
-        # "video_keys": video_keys,
-        # "image_keys": image_keys,
-        # "shapes": {**sequence_shapes, **video_shapes, **image_shapes},
-        # "names": names,
-        # "videos": {"videos_path": DEFAULT_VIDEO_PATH} if video_keys else None,
+        "keys": list(robot.names),
+        "video_keys": list(camera_shapes) if use_videos else [],
+        "image_keys": [] if use_videos else list(camera_shapes),
+        "shapes": {**shapes, **camera_shapes},
+        "names": robot.names,
+        "videos": {"videos_path": DEFAULT_VIDEO_PATH} if use_videos else None,
     }
 
 
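
For a robot with one 6-motor arm and a single 640x480 camera named `laptop`, the returned metadata would look roughly like this (hypothetical values; path constants and bookkeeping fields elided):

    {
        "codebase_version": "v2.0",  # illustrative
        "fps": 30,
        "splits": {},
        "keys": ["action", "observation.state"],
        "video_keys": ["observation.images.laptop"],  # [] when use_videos=False
        "image_keys": [],  # holds the camera keys instead when use_videos=False
        "shapes": {
            "action": 6,
            "observation.state": 6,
            "observation.images.laptop": {"width": 640, "height": 480, "channels": 3},
        },
        "names": {"action": [...], "observation.state": [...]},
        "videos": {"videos_path": "..."},  # None when use_videos=False
    }
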
@@ -270,6 +285,7 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
     return delta_indices
 
 
+# TODO(aliberts): remove
 def load_previous_and_future_frames(
     item: dict[str, torch.Tensor],
     hf_dataset: datasets.Dataset,
@@ -363,6 +379,7 @@ def load_previous_and_future_frames(
     return item
 
 
+# TODO(aliberts): remove
 def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
     """
     Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
@@ -417,6 +434,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
     return episode_data_index
 
 
+# TODO(aliberts): remove
 def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
     """Reset the `episode_index` of the provided HuggingFace Dataset.
 
@@ -454,7 +472,7 @@
     iterator = iter(iterable)
 
 
-def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
+def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
     """Create a branch on a existing Hugging Face repo. Delete the branch if it already
     exists before creating it.
     """
--- a/lerobot/common/robot_devices/cameras/opencv.py
+++ b/lerobot/common/robot_devices/cameras/opencv.py
@@ -192,6 +192,7 @@ class OpenCVCameraConfig:
     width: int | None = None
     height: int | None = None
     color_mode: str = "rgb"
+    channels: int | None = None
     rotation: int | None = None
     mock: bool = False
 
@@ -201,6 +202,8 @@ class OpenCVCameraConfig:
                 f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
             )
 
+        self.channels = 3
+
         if self.rotation not in [-90, None, 90, 180]:
             raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
 
@@ -268,6 +271,7 @@ class OpenCVCamera:
         self.fps = config.fps
         self.width = config.width
         self.height = config.height
+        self.channels = config.channels
         self.color_mode = config.color_mode
         self.mock = config.mock
 
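
The new `channels` field exists so that `create_empty_dataset_info` can read a camera's full image shape (width, height, channels) before any frame has been captured; `__post_init__` pins it to 3, since decoded frames are 3-channel RGB or BGR. A small sketch (import path and constructor values assumed):

    from lerobot.common.robot_devices.cameras.opencv import OpenCVCameraConfig

    config = OpenCVCameraConfig(fps=30, width=640, height=480)  # values illustrative
    assert config.channels == 3  # set unconditionally in __post_init__
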
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -15,7 +15,8 @@ import torch
 import tqdm
 from termcolor import colored
 
-from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
+from lerobot.common.datasets.image_writer import safe_stop_image_writer
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
@@ -227,7 +228,7 @@ def control_loop(
     control_time_s=None,
     teleoperate=False,
     display_cameras=False,
-    dataset=None,
+    dataset: LeRobotDataset | None = None,
     events=None,
     policy=None,
     device=None,
@@ -268,7 +269,8 @@
             action = {"action": action}
 
         if dataset is not None:
-            add_frame(dataset, observation, action)
+            frame = {**observation, **action}
+            dataset.add_frame(frame)
 
         if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
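
`control_loop` now assembles the single flat dict that `LeRobotDataset.add_frame` expects by merging the observation and action dicts, so their keys must not collide. A sketch of the merge with illustrative contents:

    observation = {"observation.state": [0.1, 0.2]}  # plus observation.images.* entries
    action = {"action": [0.3, 0.4]}

    frame = {**observation, **action}  # one flat dict per timestep
    assert sorted(frame) == ["action", "observation.state"]
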
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -349,6 +349,13 @@ class ManipulatorRobot:
         self.is_connected = False
         self.logs = {}
 
+        action_names = [f"{arm}_{motor}" for arm, bus in self.leader_arms.items() for motor in bus.motors]
+        state_names = [f"{arm}_{motor}" for arm, bus in self.follower_arms.items() for motor in bus.motors]
+        self.names = {
+            "action": action_names,
+            "observation.state": state_names,
+        }
+
     @property
     def has_camera(self):
         return len(self.cameras) > 0
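
`self.names` gives every motor a stable column name of the form `{arm}_{motor}`, which `create_empty_dataset_info` stores under `names` and reduces with `len` to the flat `shapes`. For a hypothetical robot whose leader and follower arms are both called `main`, each with `shoulder` and `elbow` motors:

    names = {
        "action": ["main_shoulder", "main_elbow"],  # from the leader arm motors
        "observation.state": ["main_shoulder", "main_elbow"],  # from the follower arm motors
    }
    shapes = {key: len(n) for key, n in names.items()}
    assert shapes == {"action": 2, "observation.state": 2}
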
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -105,11 +105,11 @@ from pathlib import Path
 from typing import List
 
 # from safetensors.torch import load_file, save_file
+from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.populate_dataset import (
     create_lerobot_dataset,
     delete_current_episode,
-    init_dataset,
     save_current_episode,
 )
 from lerobot.common.robot_devices.control_utils import (
@@ -233,16 +233,12 @@ def record(
 
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    dataset = init_dataset(
-        repo_id,
-        root,
-        force_override,
-        fps,
-        video,
-        write_images=robot.has_camera,
+    image_writer = ImageWriter(
+        write_dir=root,
         num_image_writer_processes=num_image_writer_processes,
         num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
     )
+    dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
 
     if not robot.is_connected:
         robot.connect()
@@ -260,8 +256,9 @@
         if has_method(robot, "teleop_safety_stop"):
             robot.teleop_safety_stop()
 
+    recorded_episodes = 0
     while True:
-        if dataset["num_episodes"] >= num_episodes:
+        if recorded_episodes >= num_episodes:
             break
 
         episode_index = dataset["num_episodes"]
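
Put together, `record()` now wires the pieces in this order: build an ImageWriter, create the empty dataset around it, then count finished episodes locally rather than reading them back from the old `init_dataset` dict. A condensed sketch of the new flow, reusing `record()`'s own variables:

    image_writer = ImageWriter(
        write_dir=root,
        num_image_writer_processes=num_image_writer_processes,
        num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
    )
    dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)

    recorded_episodes = 0
    while True:
        if recorded_episodes >= num_episodes:
            break
        # ...teleoperate or run the policy, calling dataset.add_frame(frame) each step...
        recorded_episodes += 1  # assumed: incremented once an episode is saved
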