Add add_episode & task logic

This commit is contained in:
Simon Alibert 2024-10-21 19:30:20 +02:00
parent 9ebf8b88ec
commit 299451af81
4 changed files with 203 additions and 18 deletions

View File

@ -19,16 +19,19 @@ from pathlib import Path
from typing import Callable
import datasets
import pyarrow.parquet as pq
import torch
import torch.utils
from datasets import load_dataset
from huggingface_hub import snapshot_download
from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
append_jsonl,
check_delta_timestamps,
check_timestamps_sync,
create_branch,
create_empty_dataset_info,
get_delta_indices,
get_episode_data_index,
@ -160,6 +163,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.video_backend = video_backend if video_backend is not None else "pyav"
self.image_writer = image_writer
self.episode_buffer = {}
self.consolidated = True
self.delta_indices = None
# Load metadata
@ -192,6 +196,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
def push_to_repo(self, push_videos: bool = True) -> None:
if not self.consolidated:
raise RuntimeError(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
"Please use the '.consolidate()' method first."
)
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
upload_folder(
repo_id=self.repo_id,
folder_path=self.root,
repo_type="dataset",
ignore_patterns=ignore_patterns,
)
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
@ -303,11 +325,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Number of samples/frames in selected episodes."""
return len(self.hf_dataset)
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
@ -318,6 +335,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]
@property
def total_chunks(self) -> int:
"""Total number of chunks (groups of episodes)."""
@ -331,7 +358,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
self.info.get("shapes")
return self.info["shapes"]
@property
def features(self) -> datasets.Features:
"""Shapes for the different features."""
if self.hf_dataset is not None:
return self.hf_dataset.features
elif self.episode_buffer is None:
raise NotImplementedError(
"Dataset features must be infered from an existing hf_dataset or episode_buffer."
)
features = {}
for key in self.episode_buffer:
if key in ["episode_index", "frame_index", "index", "task_index"]:
features[key] = datasets.Value(dtype="int64")
elif key in ["next.done", "next.success"]:
features[key] = datasets.Value(dtype="bool")
elif key in ["timestamp", "next.reward"]:
features[key] = datasets.Value(dtype="float32")
elif key in self.image_keys:
features[key] = datasets.Image()
elif key in self.keys:
features[key] = datasets.Sequence(
length=self.shapes[key], feature=datasets.Value(dtype="float32")
)
return datasets.Features(features)
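As a rough illustration of the dtype mapping above (the key names and the 6-dim state are made up, not taken from this diff), the inferred features for a small episode buffer could look like:

import datasets

# Hypothetical result of the inference above, for a buffer holding one camera stream
# and a 6-dimensional "observation.state" vector:
expected_features = datasets.Features(
    {
        "episode_index": datasets.Value(dtype="int64"),
        "timestamp": datasets.Value(dtype="float32"),
        "next.done": datasets.Value(dtype="bool"),
        "observation.images.cam_high": datasets.Image(),
        "observation.state": datasets.Sequence(length=6, feature=datasets.Value(dtype="float32")),
    }
)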
@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}
def get_task_index(self, task: str) -> int:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise returns the index the new task will be assigned (i.e. the current total_tasks).
"""
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
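A minimal standalone sketch of that lookup, with made-up tasks, showing that an unseen task gets the next free index:

tasks = {0: "pick up the cube"}  # task_index -> task, as in self.tasks
task_to_task_index = {task: idx for idx, task in tasks.items()}

def get_task_index(task: str) -> int:
    # Known tasks keep their index; a new task gets index == total_tasks.
    return task_to_task_index.get(task, len(tasks))

assert get_task_index("pick up the cube") == 0
assert get_task_index("stack the blocks") == 1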
def current_episode_index(self, idx: int) -> int:
episode_index = self.hf_dataset["episode_index"][idx]
@ -447,12 +513,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
f")"
)
def _create_episode_buffer(self) -> dict:
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
# TODO(aliberts): Handle resume
return {
"chunk": self.total_chunks,
"episode_index": self.total_episodes,
"size": 0,
"episode_index": self.total_episodes if episode_index is None else episode_index,
"task_index": None,
"frame_index": [],
"timestamp": [],
"next.done": [],
@ -490,6 +556,92 @@ class LeRobotDataset(torch.utils.data.Dataset):
file_path=img_path,
)
def add_episode(self, task: str, encode_videos: bool = False) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
Use encode_videos if you want to encode videos while saving each episode. Otherwise, you can do it
later with dataset.consolidate(). This gives more flexibility on when to spend time on video
encoding.
"""
episode_length = self.episode_buffer.pop("size")
episode_index = self.episode_buffer["episode_index"]
task_index = self.get_task_index(task)
self.episode_buffer["next.done"][-1] = True
for key in self.episode_buffer:
if key in self.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
elif key == "episode_index":
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
elif key == "task_index":
self.episode_buffer[key] = torch.full((episode_length,), task_index)
else:
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
self._save_episode_table(episode_index)
if encode_videos:
pass # TODO
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
self.consolidated = False
def _save_episode_table(self, episode_index: int) -> None:
features = self.features
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
ep_table = ep_dataset._data.table
ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(ep_table, ep_data_path)
def _save_episode_to_metadata(
self, episode_index: int, episode_length: int, task: str, task_index: int
) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
if task_index not in self.tasks:
self.info["total_tasks"] += 1
self.tasks[task_index] = task
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / "meta/info.json")
episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"length": episode_length,
}
append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
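For reference, saving a first episode with a new task would append one line to each metadata file, roughly as follows (task text and length are illustrative):
meta/tasks.jsonl    → {"task_index": 0, "task": "pick up the cube"}
meta/episodes.jsonl → {"episode_index": 0, "tasks": ["pick up the cube"], "length": 342}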
def delete_episode(self) -> None:
pass # TODO
def consolidate(self) -> None:
pass # TODO
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lengths
# - [ ] number of files
# - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
# - [ ] no remaining self.image_writer.dir
self.consolidated = True
@classmethod
def create(
cls,
@ -508,6 +660,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._version = CODEBASE_VERSION
obj.tolerance_s = tolerance_s
obj.image_writer = image_writer
obj.hf_dataset = None
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warn(
@ -515,12 +668,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
obj.tasks = {}
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
write_json(obj.info, obj.root / "meta/info.json")
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
# It is used to know when certain operations are needed (for instance, computing dataset statistics).
# In order to be able to push the dataset to the hub, it needs to be consolidated first.
obj.consolidated = True
# obj.episodes = None
# obj.image_transforms = None
# obj.delta_timestamps = None
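Taken together, the new API in this file is meant to be used roughly as follows; robot setup and the per-frame filling of episode_buffer are elided, and only methods introduced or touched in this diff are shown:

dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)

# ... fill dataset.episode_buffer with the frames of one episode ...
dataset.add_episode(task="pick up the cube")  # writes the episode parquet + metadata, sets consolidated=False

dataset.consolidate()                    # sanity checks / deferred video encoding (still a TODO in this commit)
dataset.push_to_repo(push_videos=True)   # raises if the dataset has not been consolidated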

View File

@ -81,6 +81,11 @@ def write_json(data: dict, fpath: Path) -> None:
json.dump(data, f, indent=4, ensure_ascii=False)
def append_jsonl(data: dict, fpath: Path) -> None:
with jsonlines.open(fpath, "a") as writer:
writer.write(data)
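A minimal sketch of reading such a file back with the jsonlines package (the path is illustrative):

import jsonlines

with jsonlines.open("meta/tasks.jsonl") as reader:
    tasks = {line["task_index"]: line["task"] for line in reader}  # one dict per appended line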
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to

View File

@ -248,7 +248,7 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and fps is not None and dataset["fps"] != fps:
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0

View File

@ -109,8 +109,6 @@ from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import (
create_lerobot_dataset,
delete_current_episode,
save_current_episode,
)
from lerobot.common.robot_devices.control_utils import (
control_loop,
@ -195,6 +193,7 @@ def record(
robot: Robot,
root: str,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None,
@ -219,6 +218,11 @@ def record(
device = None
use_amp = None
if single_task:
task = single_task
else:
raise NotImplementedError("Only single-task recording is supported for now")
# Load pretrained policy
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@ -235,8 +239,8 @@ def record(
sanity_check_dataset_name(repo_id, policy)
image_writer = ImageWriter(
write_dir=root,
num_image_writer_processes=num_image_writer_processes,
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
@ -261,7 +265,12 @@ def record(
if recorded_episodes >= num_episodes:
break
episode_index = dataset["num_episodes"]
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable events if
# input() messes with them.
# if multi_task:
# task = input("Enter your task description: ")
episode_index = dataset.episode_buffer["episode_index"]
log_say(f"Recording episode {episode_index}", play_sounds)
record_episode(
dataset=dataset,
@ -289,11 +298,11 @@ def record(
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
delete_current_episode(dataset)
dataset.delete_episode()
continue
# Increment by one dataset["current_episode_index"]
save_current_episode(dataset)
dataset.add_episode(task)
if events["stop_recording"]:
break
@ -378,9 +387,21 @@ if __name__ == "__main__":
)
parser_record = subparsers.add_parser("record", parents=[base_parser])
task_args = parser_record.add_mutually_exclusive_group(required=True)
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the task performed during the recording.",
)
# TODO(aliberts): add multi-task support
# task_args.add_argument(
# "--multi-task",
# type=int,
# help="You will need to enter the task performed at the start of each episode.",
# )
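For illustration, a recording run with the new flag might be launched like this; the script path and any other required flags are assumed from the surrounding repo, not shown in this diff:

python lerobot/scripts/control_robot.py record \
    --single-task "Pick up the cube and place it in the bin." \
    --fps 30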
parser_record.add_argument(
"--root",
type=Path,