Add add_episode & task logic
This commit is contained in:
parent
9ebf8b88ec
commit
299451af81
|
@ -19,16 +19,19 @@ from pathlib import Path
|
|||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.utils import (
|
||||
append_jsonl,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
create_branch,
|
||||
create_empty_dataset_info,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
|
@ -160,6 +163,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.image_writer = image_writer
|
||||
self.episode_buffer = {}
|
||||
self.consolidated = True
|
||||
self.delta_indices = None
|
||||
|
||||
# Load metadata
|
||||
|
@ -192,6 +196,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# - [ ] Update episode_index (arg update=True)
|
||||
# - [ ] Update info.json (arg update=True)
|
||||
|
||||
def push_to_repo(self, push_videos: bool = True) -> None:
|
||||
if not self.consolidated:
|
||||
raise RuntimeError(
|
||||
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
|
||||
"Please use the '.consolidate()' method first."
|
||||
)
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
|
||||
upload_folder(
|
||||
repo_id=self.repo_id,
|
||||
folder_path=self.root,
|
||||
repo_type="dataset",
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
|
@ -303,11 +325,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""Number of samples/frames in selected episodes."""
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
|
@ -318,6 +335,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def total_chunks(self) -> int:
|
||||
"""Total number of chunks (groups of episodes)."""
|
||||
|
@ -331,7 +358,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
self.info.get("shapes")
|
||||
return self.info["shapes"]
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
"""Shapes for the different features."""
|
||||
if self.hf_dataset is not None:
|
||||
return self.hf_dataset.features
|
||||
elif self.episode_buffer is None:
|
||||
raise NotImplementedError(
|
||||
"Dataset features must be infered from an existing hf_dataset or episode_buffer."
|
||||
)
|
||||
|
||||
features = {}
|
||||
for key in self.episode_buffer:
|
||||
if key in ["episode_index", "frame_index", "index", "task_index"]:
|
||||
features[key] = datasets.Value(dtype="int64")
|
||||
elif key in ["next.done", "next.success"]:
|
||||
features[key] = datasets.Value(dtype="bool")
|
||||
elif key in ["timestamp", "next.reward"]:
|
||||
features[key] = datasets.Value(dtype="float32")
|
||||
elif key in self.image_keys:
|
||||
features[key] = datasets.Image()
|
||||
elif key in self.keys:
|
||||
features[key] = datasets.Sequence(
|
||||
length=self.shapes[key], feature=datasets.Value(dtype="float32")
|
||||
)
|
||||
|
||||
return datasets.Features(features)
|
||||
|
||||
@property
|
||||
def task_to_task_index(self) -> dict:
|
||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
||||
|
||||
def get_task_index(self, task: str) -> int:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise creates a new task_index.
|
||||
"""
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
|
||||
def current_episode_index(self, idx: int) -> int:
|
||||
episode_index = self.hf_dataset["episode_index"][idx]
|
||||
|
@ -447,12 +513,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
f")"
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self) -> dict:
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
# TODO(aliberts): Handle resume
|
||||
return {
|
||||
"chunk": self.total_chunks,
|
||||
"episode_index": self.total_episodes,
|
||||
"size": 0,
|
||||
"episode_index": self.total_episodes if episode_index is None else episode_index,
|
||||
"task_index": None,
|
||||
"frame_index": [],
|
||||
"timestamp": [],
|
||||
"next.done": [],
|
||||
|
@ -490,6 +556,92 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
file_path=img_path,
|
||||
)
|
||||
|
||||
def add_episode(self, task: str, encode_videos: bool = False) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
|
||||
Use encode_videos if you want to encode videos during the saving of each episode. Otherwise,
|
||||
you can do it later during dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
"""
|
||||
episode_length = self.episode_buffer.pop("size")
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
task_index = self.get_task_index(task)
|
||||
self.episode_buffer["next.done"][-1] = True
|
||||
|
||||
for key in self.episode_buffer:
|
||||
if key in self.keys:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
elif key == "episode_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
||||
elif key == "task_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), task_index)
|
||||
else:
|
||||
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
|
||||
|
||||
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
|
||||
self._save_episode_table(episode_index)
|
||||
|
||||
if encode_videos:
|
||||
pass # TODO
|
||||
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
features = self.features
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
|
||||
ep_table = ep_dataset._data.table
|
||||
ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(ep_table, ep_data_path)
|
||||
|
||||
def _save_episode_to_metadata(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
if task_index not in self.tasks:
|
||||
self.info["total_tasks"] += 1
|
||||
self.tasks[task_index] = task
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
write_json(self.info, self.root / "meta/info.json")
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": [task],
|
||||
"length": episode_length,
|
||||
}
|
||||
append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
|
||||
|
||||
def delete_episode(self) -> None:
|
||||
pass # TODO
|
||||
|
||||
def consolidate(self) -> None:
|
||||
pass # TODO
|
||||
# Sanity checks:
|
||||
# - [ ] shapes
|
||||
# - [ ] ep_lenghts
|
||||
# - [ ] number of files
|
||||
# - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
|
||||
# - [ ] no remaining self.image_writer.dir
|
||||
self.consolidated = True
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
|
@ -508,6 +660,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj._version = CODEBASE_VERSION
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = image_writer
|
||||
obj.hf_dataset = None
|
||||
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warn(
|
||||
|
@ -515,12 +668,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
|
||||
obj.tasks = {}
|
||||
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
|
||||
write_json(obj.info, obj.root / "meta/info.json")
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj._create_episode_buffer()
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
|
||||
# It is used to know when certain operations are need (for instance, computing dataset statistics).
|
||||
# In order to be able to push the dataset to the hub, it needs to be consolidation first.
|
||||
obj.consolidated = True
|
||||
|
||||
# obj.episodes = None
|
||||
# obj.image_transforms = None
|
||||
# obj.delta_timestamps = None
|
||||
|
|
|
@ -81,6 +81,11 @@ def write_json(data: dict, fpath: Path) -> None:
|
|||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def append_jsonl(data: dict, fpath: Path) -> None:
|
||||
with jsonlines.open(fpath, "a") as writer:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
|
|
|
@ -248,7 +248,7 @@ def control_loop(
|
|||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset["fps"] != fps:
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
timestamp = 0
|
||||
|
|
|
@ -109,8 +109,6 @@ from lerobot.common.datasets.image_writer import ImageWriter
|
|||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import (
|
||||
create_lerobot_dataset,
|
||||
delete_current_episode,
|
||||
save_current_episode,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
|
@ -195,6 +193,7 @@ def record(
|
|||
robot: Robot,
|
||||
root: str,
|
||||
repo_id: str,
|
||||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
fps: int | None = None,
|
||||
|
@ -219,6 +218,11 @@ def record(
|
|||
device = None
|
||||
use_amp = None
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
else:
|
||||
raise NotImplementedError("Only single-task recording is supported for now")
|
||||
|
||||
# Load pretrained policy
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
|
@ -235,8 +239,8 @@ def record(
|
|||
sanity_check_dataset_name(repo_id, policy)
|
||||
image_writer = ImageWriter(
|
||||
write_dir=root,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)
|
||||
|
||||
|
@ -261,7 +265,12 @@ def record(
|
|||
if recorded_episodes >= num_episodes:
|
||||
break
|
||||
|
||||
episode_index = dataset["num_episodes"]
|
||||
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
|
||||
# input() messes with them.
|
||||
# if multi_task:
|
||||
# task = input("Enter your task description: ")
|
||||
|
||||
episode_index = dataset.episode_buffer["episode_index"]
|
||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
|
@ -289,11 +298,11 @@ def record(
|
|||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
delete_current_episode(dataset)
|
||||
dataset.delete_episode()
|
||||
continue
|
||||
|
||||
# Increment by one dataset["current_episode_index"]
|
||||
save_current_episode(dataset)
|
||||
dataset.add_episode(task)
|
||||
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
|
@ -378,9 +387,21 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
task_args = parser_record.add_mutually_exclusive_group(required=True)
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
type=str,
|
||||
help="A short but accurate description of the task performed during the recording.",
|
||||
)
|
||||
# TODO(aliberts): add multi-task support
|
||||
# task_args.add_argument(
|
||||
# "--multi-task",
|
||||
# type=int,
|
||||
# help="You will need to enter the task performed at the start of each episode.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
|
|
Loading…
Reference in New Issue