Add frame level task (#693)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Remi 2025-02-14 14:22:22 +01:00 committed by GitHub
parent d67ca342e9
commit 9d6886dd08
6 changed files with 105 additions and 50 deletions

View File

@@ -180,6 +180,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
                 # Shift reward and success by +1 until the last item of the episode
                 "next.reward": reward[i + (frame_idx < num_frames - 1)],
                 "next.success": success[i + (frame_idx < num_frames - 1)],
+                "task": PUSHT_TASK,
             }
             frame["observation.state"] = torch.from_numpy(agent_pos[i])
@@ -191,7 +192,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
             dataset.add_frame(frame)

-    dataset.save_episode(task=PUSHT_TASK)
+    dataset.save_episode()

     dataset.consolidate()
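
For orientation, a minimal sketch of the recording flow this port migrates to: the task string now rides along with every frame instead of being passed once at episode save time. The `repo_id`, feature spec, and data below are illustrative placeholders, not values from this commit.

    import numpy as np
    import torch

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    PUSHT_TASK = "Push the T-shaped block onto the T-shaped target."

    # Hypothetical minimal feature spec; the real PushT port declares many more keys.
    features = {"observation.state": {"dtype": "float32", "shape": (2,), "names": None}}
    dataset = LeRobotDataset.create(repo_id="user/pusht_demo", fps=10, features=features)

    agent_pos = np.zeros((1, 2), dtype=np.float32)  # stand-in for real episode data
    frame = {"observation.state": torch.from_numpy(agent_pos[0]), "task": PUSHT_TASK}
    dataset.add_frame(frame)  # the task travels with every frame...
    dataset.save_episode()    # ...instead of being passed here
    dataset.consolidate()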

View File

@@ -87,7 +87,7 @@ class LeRobotDatasetMetadata:
         self.pull_from_repo(allow_patterns="meta/")
         self.info = load_info(self.root)
         self.stats = load_stats(self.root)
-        self.tasks = load_tasks(self.root)
+        self.tasks, self.task_to_task_index = load_tasks(self.root)
         self.episodes = load_episodes(self.root)

     def pull_from_repo(
@@ -202,31 +202,35 @@ class LeRobotDatasetMetadata:
         """Max number of episodes per chunk."""
         return self.info["chunks_size"]

-    @property
-    def task_to_task_index(self) -> dict:
-        return {task: task_idx for task_idx, task in self.tasks.items()}
-
-    def get_task_index(self, task: str) -> int:
+    def get_task_index(self, task: str) -> int | None:
         """
         Given a task in natural language, returns its task_index if the task already exists in the dataset,
-        otherwise creates a new task_index.
+        otherwise return None.
         """
-        task_index = self.task_to_task_index.get(task, None)
-        return task_index if task_index is not None else self.total_tasks
+        return self.task_to_task_index.get(task, None)

-    def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
-        self.info["total_episodes"] += 1
-        self.info["total_frames"] += episode_length
+    def add_task(self, task: str):
+        """
+        Given a task in natural language, add it to the dictionary of tasks.
+        """
+        if task in self.task_to_task_index:
+            raise ValueError(f"The task '{task}' already exists and can't be added twice.")

-        if task_index not in self.tasks:
-            self.info["total_tasks"] += 1
-            self.tasks[task_index] = task
-            task_dict = {
-                "task_index": task_index,
-                "task": task,
-            }
-            append_jsonlines(task_dict, self.root / TASKS_PATH)
+        task_index = self.info["total_tasks"]
+        self.task_to_task_index[task] = task_index
+        self.tasks[task_index] = task
+        self.info["total_tasks"] += 1
+
+        task_dict = {
+            "task_index": task_index,
+            "task": task,
+        }
+        append_jsonlines(task_dict, self.root / TASKS_PATH)
+
+    def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None:
+        self.info["total_episodes"] += 1
+        self.info["total_frames"] += episode_length

         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
             self.info["total_chunks"] += 1
@@ -237,7 +241,7 @@ class LeRobotDatasetMetadata:
         episode_dict = {
             "episode_index": episode_index,
-            "tasks": [task],
+            "tasks": episode_tasks,
             "length": episode_length,
         }
         self.episodes.append(episode_dict)
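
The visible effect on meta/episodes.jsonl: the `tasks` field can now hold several distinct tasks per episode instead of a single wrapped string. An illustrative entry (all values invented):

    episode_dict = {
        "episode_index": 7,
        "tasks": ["pick up the cube", "place the cube in the bin"],
        "length": 250,
    }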
@@ -313,7 +317,8 @@ class LeRobotDatasetMetadata:
         features = {**features, **DEFAULT_FEATURES}

-        obj.tasks, obj.stats, obj.episodes = {}, {}, []
+        obj.tasks, obj.task_to_task_index = {}, {}
+        obj.stats, obj.episodes = {}, []
         obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
@@ -691,10 +696,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def create_episode_buffer(self, episode_index: int | None = None) -> dict:
         current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
-        return {
-            "size": 0,
-            **{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
-        }
+        ep_buffer = {}
+        # size and task are special cases that are not in self.features
+        ep_buffer["size"] = 0
+        ep_buffer["task"] = []
+        for key in self.features:
+            ep_buffer[key] = current_ep_idx if key == "episode_index" else []
+        return ep_buffer

     def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
         fpath = DEFAULT_IMAGE_PATH.format(
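
What the rewritten buffer looks like in practice, as a runnable sketch. The feature list here is invented; in the library it comes from `self.features`, which also folds in the default bookkeeping keys:

    episode_index = 3
    features = ["observation.state", "action", "timestamp", "frame_index", "episode_index", "index", "task_index"]

    ep_buffer = {"size": 0, "task": []}  # the two keys that live outside self.features
    for key in features:
        ep_buffer[key] = episode_index if key == "episode_index" else []

    print(ep_buffer)
    # {'size': 0, 'task': [], 'observation.state': [], 'action': [], 'timestamp': [],
    #  'frame_index': [], 'episode_index': 3, 'index': [], 'task_index': []}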
@@ -718,6 +726,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
         # check the dtype and shape matches, etc.
+        if "task" not in frame:
+            raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionary.")

         if self.episode_buffer is None:
             self.episode_buffer = self.create_episode_buffer()
@@ -728,13 +738,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_buffer["timestamp"].append(timestamp)

         for key in frame:
-            if key not in self.features:
-                raise ValueError(key)
+            if key == "task":
+                # Note: we associate the task in natural language to its task index during `save_episode`
+                self.episode_buffer["task"].append(frame["task"])
+                continue

-            if self.features[key]["dtype"] not in ["image", "video"]:
-                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
-                self.episode_buffer[key].append(item)
-            elif self.features[key]["dtype"] in ["image", "video"]:
+            if key not in self.features:
+                raise ValueError(
+                    f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
+                )
+
+            if self.features[key]["dtype"] in ["image", "video"]:
                 img_path = self._get_image_file_path(
                     episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
                 )
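
The per-key routing above, distilled into a runnable toy: "task" is buffered on its own, unknown keys fail fast, and numeric features land in per-key lists. The feature spec is invented, and the image/video branch (which writes files to disk) is omitted:

    import torch

    features = {"observation.state": {"dtype": "float32"}}
    episode_buffer = {"size": 0, "task": [], "observation.state": []}

    frame = {"observation.state": torch.randn(6), "task": "Fold the shirt."}
    for key in frame:
        if key == "task":
            episode_buffer["task"].append(frame["task"])  # kept as a string, indexed later
            continue
        if key not in features:
            raise ValueError(f"An element of the frame is not in the features. '{key}' not in '{features.keys()}'.")
        # numeric features: tensors are converted and appended per key
        item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
        episode_buffer[key].append(item)
    episode_buffer["size"] += 1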
@@ -742,10 +756,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 img_path.parent.mkdir(parents=True, exist_ok=True)
                 self._save_image(frame[key], img_path)
                 self.episode_buffer[key].append(str(img_path))
+            else:
+                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
+                self.episode_buffer[key].append(item)

         self.episode_buffer["size"] += 1

-    def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
+    def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
         """
         This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
         disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
@@ -758,7 +775,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if not episode_data:
             episode_buffer = self.episode_buffer

+        # size and task are special cases that won't be added to hf_dataset
         episode_length = episode_buffer.pop("size")
+        tasks = episode_buffer.pop("task")
+        episode_tasks = list(set(tasks))
+
         episode_index = episode_buffer["episode_index"]
         if episode_index != self.meta.total_episodes:
             # TODO(aliberts): Add option to use existing episode_index
@ -772,21 +793,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
"You must add one or several frames with `add_frame` before calling `add_episode`." "You must add one or several frames with `add_frame` before calling `add_episode`."
) )
task_index = self.meta.get_task_index(task)
if not set(episode_buffer.keys()) == set(self.features): if not set(episode_buffer.keys()) == set(self.features):
raise ValueError() raise ValueError(
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
)
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionnary
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items(): for key, ft in self.features.items():
if key == "index": # index, episode_index, task_index are already processed above, and image and video
episode_buffer[key] = np.arange( # are processed separately by storing image path and frame info as meta data
self.meta.total_frames, self.meta.total_frames + episode_length if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
)
elif key == "episode_index":
episode_buffer[key] = np.full((episode_length,), episode_index)
elif key == "task_index":
episode_buffer[key] = np.full((episode_length,), task_index)
elif ft["dtype"] in ["image", "video"]:
continue continue
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1: elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"]) episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
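
How the per-frame task strings become the `task_index` column, in a runnable miniature. The task list is invented, and the index assignment here is simplified (in the library it goes through `meta.add_task` and `meta.get_task_index`):

    import numpy as np

    tasks = ["pick", "pick", "place", "pick"]  # one entry per frame, duplicates included
    episode_tasks = list(set(tasks))           # distinct tasks, registered once each

    task_to_task_index: dict[str, int] = {}
    for task in episode_tasks:
        task_to_task_index.setdefault(task, len(task_to_task_index))

    task_index_col = np.array([task_to_task_index[task] for task in tasks])
    print(task_index_col)  # e.g. [0 0 1 0]; set() ordering may swap the indices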
@@ -798,7 +825,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self._wait_image_writer()
         self._save_episode_table(episode_buffer, episode_index)

-        self.meta.save_episode(episode_index, episode_length, task, task_index)
+        self.meta.save_episode(episode_index, episode_length, episode_tasks)

         if encode_videos and len(self.meta.video_keys) > 0:
             video_paths = self.encode_episode_videos(episode_index)

View File

@@ -170,7 +170,9 @@ def load_stats(local_dir: Path) -> dict:
-def load_tasks(local_dir: Path) -> dict:
+def load_tasks(local_dir: Path) -> tuple[dict, dict]:
     tasks = load_jsonlines(local_dir / TASKS_PATH)
-    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    task_to_task_index = {task: task_index for task_index, task in tasks.items()}
+    return tasks, task_to_task_index


 def load_episodes(local_dir: Path) -> dict:
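
The round-trip this enables, as a runnable sketch (the jsonl content is invented):

    # Given a meta/tasks.jsonl containing, say:
    #   {"task_index": 0, "task": "Push the block."}
    #   {"task_index": 1, "task": "Stack the cups."}
    # load_tasks() now returns both directions of the mapping:
    tasks = {0: "Push the block.", 1: "Stack the cups."}
    task_to_task_index = {task: task_index for task_index, task in tasks.items()}
    assert task_to_task_index == {"Push the block.": 0, "Stack the cups.": 1}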

View File

@@ -183,6 +183,7 @@ def record_episode(
     device,
     use_amp,
     fps,
+    single_task,
 ):
     control_loop(
         robot=robot,
@@ -195,6 +196,7 @@ def record_episode(
         use_amp=use_amp,
         fps=fps,
         teleoperate=policy is None,
+        single_task=single_task,
     )
@@ -210,6 +212,7 @@ def control_loop(
     device: torch.device | str | None = None,
     use_amp: bool | None = None,
     fps: int | None = None,
+    single_task: str | None = None,
 ):
     # TODO(rcadene): Add option to record logs
     if not robot.is_connected:
@@ -224,6 +227,9 @@ def control_loop(
     if teleoperate and policy is not None:
         raise ValueError("When `teleoperate` is True, `policy` should be None.")

+    if dataset is not None and single_task is None:
+        raise ValueError("You need to provide a task as argument in `single_task`.")
+
     if dataset is not None and fps is not None and dataset.fps != fps:
         raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
@@ -248,7 +254,7 @@ def control_loop(
             action = {"action": action}

         if dataset is not None:
-            frame = {**observation, **action}
+            frame = {**observation, **action, "task": single_task}
             dataset.add_frame(frame)

         if display_cameras and not is_headless():
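
The effect on recording, distilled: when a dataset is attached, a task is mandatory, and every control-loop iteration stamps it onto the frame it stores. A runnable miniature with placeholder observation and action data:

    observation = {"observation.state": [0.1, 0.2]}  # placeholders for real robot data
    action = {"action": [0.3]}
    single_task = "Pick up the red block."

    if single_task is None:  # mirrors the new guard when a dataset is attached
        raise ValueError("You need to provide a task as argument in `single_task`.")

    frame = {**observation, **action, "task": single_task}
    assert frame["task"] == single_task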

View File

@@ -263,8 +263,8 @@ def record(
         log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
         record_episode(
-            dataset=dataset,
             robot=robot,
+            dataset=dataset,
             events=events,
             episode_time_s=cfg.episode_time_s,
             display_cameras=cfg.display_cameras,
@@ -272,6 +272,7 @@ def record(
             device=cfg.device,
             use_amp=cfg.use_amp,
             fps=cfg.fps,
+            single_task=cfg.single_task,
         )

         # Execute a few seconds without recording to give time to manually reset the environment
@@ -291,7 +292,7 @@ def record(
                 dataset.clear_episode_buffer()
                 continue

-        dataset.save_episode(cfg.single_task)
+        dataset.save_episode()
         recorded_episodes += 1

         if events["stop_recording"]:
if events["stop_recording"]: if events["stop_recording"]:

View File

@@ -93,6 +93,24 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
     assert dataset.num_frames == len(dataset)

+def test_add_frame_no_task(tmp_path):
+    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
+    with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionary."):
+        dataset.add_frame({"1d": torch.randn(1)})
+
+
+def test_add_frame(tmp_path):
+    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
+    dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert len(dataset) == 1
+    assert dataset[0]["task"] == "dummy"
+    assert dataset[0]["task_index"] == 0
+

 # TODO(aliberts):
 # - [ ] test various attributes & state from init and create
 # - [ ] test init with episodes and check num_frames