Add frame level task (#693)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Remi 2025-02-14 14:22:22 +01:00 committed by GitHub
parent d67ca342e9
commit 9d6886dd08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 105 additions and 50 deletions

View File

@ -180,6 +180,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
# Shift reward and success by +1 until the last item of the episode
"next.reward": reward[i + (frame_idx < num_frames - 1)],
"next.success": success[i + (frame_idx < num_frames - 1)],
"task": PUSHT_TASK,
}
frame["observation.state"] = torch.from_numpy(agent_pos[i])
@ -191,7 +192,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
dataset.add_frame(frame)
dataset.save_episode(task=PUSHT_TASK)
dataset.save_episode()
dataset.consolidate()

View File

@ -87,7 +87,7 @@ class LeRobotDatasetMetadata:
self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
def pull_from_repo(
@ -202,31 +202,35 @@ class LeRobotDatasetMetadata:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]
@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}
def get_task_index(self, task: str) -> int:
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise creates a new task_index.
otherwise return None.
"""
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
return self.task_to_task_index.get(task, None)
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
def add_task(self, task: str):
"""
Given a task in natural language, add it to the dictionnary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index
self.tasks[task_index] = task
self.info["total_tasks"] += 1
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)
def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
if task_index not in self.tasks:
self.info["total_tasks"] += 1
self.tasks[task_index] = task
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1
@ -237,7 +241,7 @@ class LeRobotDatasetMetadata:
episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"tasks": episode_tasks,
"length": episode_length,
}
self.episodes.append(episode_dict)
@ -313,7 +317,8 @@ class LeRobotDatasetMetadata:
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.tasks, obj.task_to_task_index = {}, {}
obj.stats, obj.episodes = {}, []
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
@ -691,10 +696,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
return {
"size": 0,
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
}
ep_buffer = {}
# size and task are special cases that are not in self.features
ep_buffer["size"] = 0
ep_buffer["task"] = []
for key in self.features:
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
@ -718,6 +726,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
# check the dtype and shape matches, etc.
if "task" not in frame:
raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
@ -728,13 +738,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["timestamp"].append(timestamp)
for key in frame:
if key not in self.features:
raise ValueError(key)
if key == "task":
# Note: we associate the task in natural language to its task index during `save_episode`
self.episode_buffer["task"].append(frame["task"])
continue
if self.features[key]["dtype"] not in ["image", "video"]:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
self.episode_buffer[key].append(item)
elif self.features[key]["dtype"] in ["image", "video"]:
if key not in self.features:
raise ValueError(
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
)
if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
@ -742,10 +756,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
img_path.parent.mkdir(parents=True, exist_ok=True)
self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path))
else:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
self.episode_buffer[key].append(item)
self.episode_buffer["size"] += 1
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
@ -758,7 +775,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not episode_data:
episode_buffer = self.episode_buffer
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
@ -772,21 +793,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
task_index = self.meta.get_task_index(task)
if not set(episode_buffer.keys()) == set(self.features):
raise ValueError()
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
)
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionnary
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items():
if key == "index":
episode_buffer[key] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
elif key == "episode_index":
episode_buffer[key] = np.full((episode_length,), episode_index)
elif key == "task_index":
episode_buffer[key] = np.full((episode_length,), task_index)
elif ft["dtype"] in ["image", "video"]:
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
@ -798,7 +825,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index)
self.meta.save_episode(episode_index, episode_length, task, task_index)
self.meta.save_episode(episode_index, episode_length, episode_tasks)
if encode_videos and len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)

View File

@ -170,7 +170,9 @@ def load_stats(local_dir: Path) -> dict:
def load_tasks(local_dir: Path) -> dict:
tasks = load_jsonlines(local_dir / TASKS_PATH)
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
def load_episodes(local_dir: Path) -> dict:

View File

@ -183,6 +183,7 @@ def record_episode(
device,
use_amp,
fps,
single_task,
):
control_loop(
robot=robot,
@ -195,6 +196,7 @@ def record_episode(
use_amp=use_amp,
fps=fps,
teleoperate=policy is None,
single_task=single_task,
)
@ -210,6 +212,7 @@ def control_loop(
device: torch.device | str | None = None,
use_amp: bool | None = None,
fps: int | None = None,
single_task: str | None = None,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@ -224,6 +227,9 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
@ -248,7 +254,7 @@ def control_loop(
action = {"action": action}
if dataset is not None:
frame = {**observation, **action}
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
if display_cameras and not is_headless():

View File

@ -263,8 +263,8 @@ def record(
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_episode(
dataset=dataset,
robot=robot,
dataset=dataset,
events=events,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
@ -272,6 +272,7 @@ def record(
device=cfg.device,
use_amp=cfg.use_amp,
fps=cfg.fps,
single_task=cfg.single_task,
)
# Execute a few seconds without recording to give time to manually reset the environment
@ -291,7 +292,7 @@ def record(
dataset.clear_episode_buffer()
continue
dataset.save_episode(cfg.single_task)
dataset.save_episode()
recorded_episodes += 1
if events["stop_recording"]:

View File

@ -93,6 +93,24 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
assert dataset.num_frames == len(dataset)
def test_add_frame_no_task(tmp_path):
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."):
dataset.add_frame({"1d": torch.randn(1)})
def test_add_frame(tmp_path):
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
dataset.save_episode(encode_videos=False)
dataset.consolidate(run_compute_stats=False)
assert len(dataset) == 1
assert dataset[0]["task"] == "dummy"
assert dataset[0]["task_index"] == 0
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames