Add add_episode & task logic
parent 9ebf8b88ec
commit 299451af81
@@ -19,16 +19,19 @@ from pathlib import Path
 from typing import Callable

 import datasets
+import pyarrow.parquet as pq
 import torch
 import torch.utils
 from datasets import load_dataset
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, upload_folder

 from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
+    append_jsonl,
     check_delta_timestamps,
     check_timestamps_sync,
+    create_branch,
     create_empty_dataset_info,
     get_delta_indices,
     get_episode_data_index,
@@ -160,6 +163,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.image_writer = image_writer
         self.episode_buffer = {}
+        self.consolidated = True
         self.delta_indices = None

         # Load metadata
@@ -192,6 +196,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # - [ ] Update episode_index (arg update=True)
         # - [ ] Update info.json (arg update=True)

+    def push_to_repo(self, push_videos: bool = True) -> None:
+        if not self.consolidated:
+            raise RuntimeError(
+                "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
+                "Please use the '.consolidate()' method first."
+            )
+        ignore_patterns = ["images/"]
+        if not push_videos:
+            ignore_patterns.append("videos/")
+
+        upload_folder(
+            repo_id=self.repo_id,
+            folder_path=self.root,
+            repo_type="dataset",
+            ignore_patterns=ignore_patterns,
+        )
+        create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
+
     def pull_from_repo(
         self,
         allow_patterns: list[str] | str | None = None,
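For reference, the new push_to_repo() boils down to a guarded call to huggingface_hub.upload_folder once the dataset is consolidated. A minimal standalone sketch of that upload step (the repo id and local folder below are placeholders, not values from this commit):

    from huggingface_hub import upload_folder

    push_videos = True
    ignore_patterns = ["images/"]  # the raw images/ folder is always excluded from the upload
    if not push_videos:
        ignore_patterns.append("videos/")

    upload_folder(
        repo_id="username/my_dataset",   # placeholder for dataset.repo_id
        folder_path="/data/my_dataset",  # placeholder for dataset.root
        repo_type="dataset",
        ignore_patterns=ignore_patterns,
    )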
@@ -303,11 +325,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Number of samples/frames in selected episodes."""
         return len(self.hf_dataset)

-    @property
-    def total_frames(self) -> int:
-        """Total number of frames saved in this dataset."""
-        return self.info["total_frames"]
-
     @property
     def num_episodes(self) -> int:
         """Number of episodes selected."""
@@ -318,6 +335,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Total number of episodes available."""
         return self.info["total_episodes"]

+    @property
+    def total_frames(self) -> int:
+        """Total number of frames saved in this dataset."""
+        return self.info["total_frames"]
+
+    @property
+    def total_tasks(self) -> int:
+        """Total number of different tasks performed in this dataset."""
+        return self.info["total_tasks"]
+
     @property
     def total_chunks(self) -> int:
         """Total number of chunks (groups of episodes)."""
@@ -331,7 +358,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @property
     def shapes(self) -> dict:
         """Shapes for the different features."""
-        self.info.get("shapes")
+        return self.info["shapes"]

+    @property
+    def features(self) -> datasets.Features:
+        """Shapes for the different features."""
+        if self.hf_dataset is not None:
+            return self.hf_dataset.features
+        elif self.episode_buffer is None:
+            raise NotImplementedError(
+                "Dataset features must be infered from an existing hf_dataset or episode_buffer."
+            )
+
+        features = {}
+        for key in self.episode_buffer:
+            if key in ["episode_index", "frame_index", "index", "task_index"]:
+                features[key] = datasets.Value(dtype="int64")
+            elif key in ["next.done", "next.success"]:
+                features[key] = datasets.Value(dtype="bool")
+            elif key in ["timestamp", "next.reward"]:
+                features[key] = datasets.Value(dtype="float32")
+            elif key in self.image_keys:
+                features[key] = datasets.Image()
+            elif key in self.keys:
+                features[key] = datasets.Sequence(
+                    length=self.shapes[key], feature=datasets.Value(dtype="float32")
+                )
+
+        return datasets.Features(features)
+
+    @property
+    def task_to_task_index(self) -> dict:
+        return {task: task_idx for task_idx, task in self.tasks.items()}
+
+    def get_task_index(self, task: str) -> int:
+        """
+        Given a task in natural language, returns its task_index if the task already exists in the dataset,
+        otherwise creates a new task_index.
+        """
+        task_index = self.task_to_task_index.get(task, None)
+        return task_index if task_index is not None else self.total_tasks
+
     def current_episode_index(self, idx: int) -> int:
         episode_index = self.hf_dataset["episode_index"][idx]
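The new features property infers a datasets.Features schema from the keys of episode_buffer when no hf_dataset has been loaded yet. A standalone sketch of that mapping, using made-up keys and shapes (in practice they come from the robot configuration):

    import datasets

    # Hypothetical buffer keys and shapes, for illustration only.
    shapes = {"observation.state": 6, "action": 6}
    image_keys = ["observation.images.cam_high"]
    buffer_keys = [
        "episode_index", "frame_index", "timestamp", "next.done",
        "observation.state", "action", "observation.images.cam_high",
    ]

    features = {}
    for key in buffer_keys:
        if key in ["episode_index", "frame_index", "index", "task_index"]:
            features[key] = datasets.Value(dtype="int64")       # integer bookkeeping columns
        elif key in ["next.done", "next.success"]:
            features[key] = datasets.Value(dtype="bool")        # episode termination flags
        elif key in ["timestamp", "next.reward"]:
            features[key] = datasets.Value(dtype="float32")     # scalar float columns
        elif key in image_keys:
            features[key] = datasets.Image()                    # camera frames
        elif key in shapes:
            features[key] = datasets.Sequence(                  # fixed-length float vectors
                length=shapes[key], feature=datasets.Value(dtype="float32")
            )

    print(datasets.Features(features))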
@@ -447,12 +513,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
             f")"
         )

-    def _create_episode_buffer(self) -> dict:
+    def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
         # TODO(aliberts): Handle resume
         return {
-            "chunk": self.total_chunks,
-            "episode_index": self.total_episodes,
             "size": 0,
+            "episode_index": self.total_episodes if episode_index is None else episode_index,
+            "task_index": None,
             "frame_index": [],
             "timestamp": [],
             "next.done": [],
@@ -490,6 +556,92 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 file_path=img_path,
             )

+    def add_episode(self, task: str, encode_videos: bool = False) -> None:
+        """
+        This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
+        disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
+        the hub.
+
+        Use encode_videos if you want to encode videos during the saving of each episode. Otherwise,
+        you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend
+        time for video encoding.
+        """
+        episode_length = self.episode_buffer.pop("size")
+        episode_index = self.episode_buffer["episode_index"]
+        task_index = self.get_task_index(task)
+        self.episode_buffer["next.done"][-1] = True
+
+        for key in self.episode_buffer:
+            if key in self.keys:
+                self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
+            elif key == "episode_index":
+                self.episode_buffer[key] = torch.full((episode_length,), episode_index)
+            elif key == "task_index":
+                self.episode_buffer[key] = torch.full((episode_length,), task_index)
+            else:
+                self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
+
+        self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
+        self._save_episode_table(episode_index)
+
+        if encode_videos:
+            pass  # TODO
+
+        # Reset the buffer
+        self.episode_buffer = self._create_episode_buffer()
+        self.consolidated = False
+
+    def _save_episode_table(self, episode_index: int) -> None:
+        features = self.features
+        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
+        ep_table = ep_dataset._data.table
+        ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False)
+        ep_data_path.parent.mkdir(parents=True, exist_ok=True)
+        pq.write_table(ep_table, ep_data_path)
+
+    def _save_episode_to_metadata(
+        self, episode_index: int, episode_length: int, task: str, task_index: int
+    ) -> None:
+        self.info["total_episodes"] += 1
+        self.info["total_frames"] += episode_length
+
+        if task_index not in self.tasks:
+            self.info["total_tasks"] += 1
+            self.tasks[task_index] = task
+            task_dict = {
+                "task_index": task_index,
+                "task": task,
+            }
+            append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
+
+        chunk = self.get_episode_chunk(episode_index)
+        if chunk >= self.total_chunks:
+            self.info["total_chunks"] += 1
+
+        self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
+        self.info["total_videos"] += len(self.video_keys)
+        write_json(self.info, self.root / "meta/info.json")
+
+        episode_dict = {
+            "episode_index": episode_index,
+            "tasks": [task],
+            "length": episode_length,
+        }
+        append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
+
+    def delete_episode(self) -> None:
+        pass  # TODO
+
+    def consolidate(self) -> None:
+        pass  # TODO
+        # Sanity checks:
+        # - [ ] shapes
+        # - [ ] ep_lenghts
+        # - [ ] number of files
+        # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
+        # - [ ] no remaining self.image_writer.dir
+        self.consolidated = True
+
     @classmethod
     def create(
         cls,
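add_episode() writes one parquet table per episode and, via _save_episode_to_metadata, appends one row to meta/tasks.jsonl (only the first time a task string is seen, since get_task_index assigns new indices incrementally) and one row to meta/episodes.jsonl. A rough sketch of those appended rows, with made-up values and a throwaway root directory:

    from pathlib import Path
    import jsonlines

    root = Path("/tmp/lerobot_demo")  # placeholder root directory
    (root / "meta").mkdir(parents=True, exist_ok=True)

    # Appended only when the task is new to the dataset.
    with jsonlines.open(root / "meta/tasks.jsonl", mode="a") as writer:
        writer.write({"task_index": 0, "task": "pick up the cube"})

    # Appended for every saved episode.
    with jsonlines.open(root / "meta/episodes.jsonl", mode="a") as writer:
        writer.write({"episode_index": 0, "tasks": ["pick up the cube"], "length": 250})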
@@ -508,6 +660,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj._version = CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_writer = image_writer
+        obj.hf_dataset = None

         if not all(cam.fps == fps for cam in robot.cameras.values()):
             logging.warn(
@@ -515,12 +668,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
             )

+        obj.tasks = {}
         obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
         write_json(obj.info, obj.root / "meta/info.json")

         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj._create_episode_buffer()
+
+        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
+        # It is used to know when certain operations are need (for instance, computing dataset statistics).
+        # In order to be able to push the dataset to the hub, it needs to be consolidation first.
+        obj.consolidated = True

         # obj.episodes = None
         # obj.image_transforms = None
         # obj.delta_timestamps = None
@@ -81,6 +81,11 @@ def write_json(data: dict, fpath: Path) -> None:
         json.dump(data, f, indent=4, ensure_ascii=False)


+def append_jsonl(data: dict, fpath: Path) -> None:
+    with jsonlines.open(fpath, "a") as writer:
+        writer.write(data)
+
+
 def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     """Get a transform function that convert items from Hugging Face dataset (pyarrow)
     to torch tensors. Importantly, images are converted from PIL, which corresponds to
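Since the new append_jsonl helper writes one JSON object per line in append mode, the resulting files can be read back with the same library. A small self-contained round-trip sketch (the file path is a throwaway example):

    import jsonlines

    def append_jsonl(data: dict, fpath) -> None:
        with jsonlines.open(fpath, "a") as writer:
            writer.write(data)

    append_jsonl({"episode_index": 0, "length": 250}, "/tmp/episodes_demo.jsonl")
    append_jsonl({"episode_index": 1, "length": 310}, "/tmp/episodes_demo.jsonl")

    with jsonlines.open("/tmp/episodes_demo.jsonl") as reader:
        print(list(reader))  # one dict per appended line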
@@ -248,7 +248,7 @@ def control_loop(
     if teleoperate and policy is not None:
         raise ValueError("When `teleoperate` is True, `policy` should be None.")

-    if dataset is not None and fps is not None and dataset["fps"] != fps:
+    if dataset is not None and fps is not None and dataset.fps != fps:
         raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")

     timestamp = 0
@@ -109,8 +109,6 @@ from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.populate_dataset import (
     create_lerobot_dataset,
-    delete_current_episode,
-    save_current_episode,
 )
 from lerobot.common.robot_devices.control_utils import (
     control_loop,
@@ -195,6 +193,7 @@ def record(
     robot: Robot,
     root: str,
     repo_id: str,
+    single_task: str,
     pretrained_policy_name_or_path: str | None = None,
     policy_overrides: List[str] | None = None,
     fps: int | None = None,
@@ -219,6 +218,11 @@ def record(
         device = None
         use_amp = None

+    if single_task:
+        task = single_task
+    else:
+        raise NotImplementedError("Only single-task recording is supported for now")
+
     # Load pretrained policy
     if pretrained_policy_name_or_path is not None:
         policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -235,8 +239,8 @@ def record(
     sanity_check_dataset_name(repo_id, policy)
     image_writer = ImageWriter(
         write_dir=root,
-        num_image_writer_processes=num_image_writer_processes,
-        num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
+        num_processes=num_image_writer_processes,
+        num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
     )
     dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)

@@ -261,7 +265,12 @@ def record(
         if recorded_episodes >= num_episodes:
             break

-        episode_index = dataset["num_episodes"]
+        # TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
+        # input() messes with them.
+        # if multi_task:
+        #     task = input("Enter your task description: ")
+
+        episode_index = dataset.episode_buffer["episode_index"]
         log_say(f"Recording episode {episode_index}", play_sounds)
         record_episode(
             dataset=dataset,
@@ -289,11 +298,11 @@ def record(
             log_say("Re-record episode", play_sounds)
             events["rerecord_episode"] = False
             events["exit_early"] = False
-            delete_current_episode(dataset)
+            dataset.delete_episode()
             continue

         # Increment by one dataset["current_episode_index"]
-        save_current_episode(dataset)
+        dataset.add_episode(task)

         if events["stop_recording"]:
             break
@@ -378,9 +387,21 @@ if __name__ == "__main__":
     )

     parser_record = subparsers.add_parser("record", parents=[base_parser])
+    task_args = parser_record.add_mutually_exclusive_group(required=True)
     parser_record.add_argument(
         "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
     )
+    task_args.add_argument(
+        "--single-task",
+        type=str,
+        help="A short but accurate description of the task performed during the recording.",
+    )
+    # TODO(aliberts): add multi-task support
+    # task_args.add_argument(
+    #     "--multi-task",
+    #     type=int,
+    #     help="You will need to enter the task performed at the start of each episode.",
+    # )
     parser_record.add_argument(
         "--root",
         type=Path,
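The record subparser now puts the task options in a required, mutually exclusive group, so exactly one of them must be passed (only --single-task exists so far; --multi-task is left as a TODO in the diff). A standalone argparse sketch of that constraint, with the surrounding parser wiring simplified:

    import argparse

    parser = argparse.ArgumentParser()
    task_args = parser.add_mutually_exclusive_group(required=True)
    task_args.add_argument(
        "--single-task",
        type=str,
        help="A short but accurate description of the task performed during the recording.",
    )

    args = parser.parse_args(["--single-task", "pick up the cube"])
    print(args.single_task)  # omitting --single-task entirely would raise a usage error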