Add frame level task (#693)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
parent
d67ca342e9
commit
9d6886dd08
|
@ -180,6 +180,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
|||
# Shift reward and success by +1 until the last item of the episode
|
||||
"next.reward": reward[i + (frame_idx < num_frames - 1)],
|
||||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
||||
"task": PUSHT_TASK,
|
||||
}
|
||||
|
||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
||||
|
@ -191,7 +192,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
|||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=PUSHT_TASK)
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.consolidate()
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ class LeRobotDatasetMetadata:
|
|||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.info = load_info(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
|
||||
def pull_from_repo(
|
||||
|
@ -202,31 +202,35 @@ class LeRobotDatasetMetadata:
|
|||
"""Max number of episodes per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def task_to_task_index(self) -> dict:
|
||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
||||
|
||||
def get_task_index(self, task: str) -> int:
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise creates a new task_index.
|
||||
otherwise return None.
|
||||
"""
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
return self.task_to_task_index.get(task, None)
|
||||
|
||||
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
||||
def add_task(self, task: str):
|
||||
"""
|
||||
Given a task in natural language, add it to the dictionnary of tasks.
|
||||
"""
|
||||
if task in self.task_to_task_index:
|
||||
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
||||
|
||||
task_index = self.info["total_tasks"]
|
||||
self.task_to_task_index[task] = task_index
|
||||
self.tasks[task_index] = task
|
||||
self.info["total_tasks"] += 1
|
||||
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
def save_episode(self, episode_index: int, episode_length: int, episode_tasks: list[str]) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
if task_index not in self.tasks:
|
||||
self.info["total_tasks"] += 1
|
||||
self.tasks[task_index] = task
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
@ -237,7 +241,7 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": [task],
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episodes.append(episode_dict)
|
||||
|
@ -313,7 +317,8 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.stats, obj.episodes = {}, []
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
|
@ -691,10 +696,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
return {
|
||||
"size": 0,
|
||||
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
|
||||
}
|
||||
ep_buffer = {}
|
||||
# size and task are special cases that are not in self.features
|
||||
ep_buffer["size"] = 0
|
||||
ep_buffer["task"] = []
|
||||
for key in self.features:
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
|
@ -718,6 +726,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"""
|
||||
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
|
||||
# check the dtype and shape matches, etc.
|
||||
if "task" not in frame:
|
||||
raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
@ -728,13 +738,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
|
||||
for key in frame:
|
||||
if key not in self.features:
|
||||
raise ValueError(key)
|
||||
if key == "task":
|
||||
# Note: we associate the task in natural language to its task index during `save_episode`
|
||||
self.episode_buffer["task"].append(frame["task"])
|
||||
continue
|
||||
|
||||
if self.features[key]["dtype"] not in ["image", "video"]:
|
||||
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
|
||||
self.episode_buffer[key].append(item)
|
||||
elif self.features[key]["dtype"] in ["image", "video"]:
|
||||
if key not in self.features:
|
||||
raise ValueError(
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
||||
)
|
||||
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
|
@ -742,10 +756,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._save_image(frame[key], img_path)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
|
||||
self.episode_buffer[key].append(item)
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
|
||||
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
|
@ -758,7 +775,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
|
||||
# size and task are special cases that won't be added to hf_dataset
|
||||
episode_length = episode_buffer.pop("size")
|
||||
tasks = episode_buffer.pop("task")
|
||||
episode_tasks = list(set(tasks))
|
||||
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
if episode_index != self.meta.total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
|
@ -772,21 +793,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
||||
)
|
||||
|
||||
task_index = self.meta.get_task_index(task)
|
||||
|
||||
if not set(episode_buffer.keys()) == set(self.features):
|
||||
raise ValueError()
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
|
||||
)
|
||||
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Add new tasks to the tasks dictionnary
|
||||
for task in episode_tasks:
|
||||
task_index = self.meta.get_task_index(task)
|
||||
if task_index is None:
|
||||
self.meta.add_task(task)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
|
||||
for key, ft in self.features.items():
|
||||
if key == "index":
|
||||
episode_buffer[key] = np.arange(
|
||||
self.meta.total_frames, self.meta.total_frames + episode_length
|
||||
)
|
||||
elif key == "episode_index":
|
||||
episode_buffer[key] = np.full((episode_length,), episode_index)
|
||||
elif key == "task_index":
|
||||
episode_buffer[key] = np.full((episode_length,), task_index)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
|
||||
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
|
||||
|
@ -798,7 +825,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
|
||||
self.meta.save_episode(episode_index, episode_length, task, task_index)
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks)
|
||||
|
||||
if encode_videos and len(self.meta.video_keys) > 0:
|
||||
video_paths = self.encode_episode_videos(episode_index)
|
||||
|
|
|
@ -170,7 +170,9 @@ def load_stats(local_dir: Path) -> dict:
|
|||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
|
|
|
@ -183,6 +183,7 @@ def record_episode(
|
|||
device,
|
||||
use_amp,
|
||||
fps,
|
||||
single_task,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
|
@ -195,6 +196,7 @@ def record_episode(
|
|||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
|
||||
|
@ -210,6 +212,7 @@ def control_loop(
|
|||
device: torch.device | str | None = None,
|
||||
use_amp: bool | None = None,
|
||||
fps: int | None = None,
|
||||
single_task: str | None = None,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
|
@ -224,6 +227,9 @@ def control_loop(
|
|||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and single_task is None:
|
||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
|
@ -248,7 +254,7 @@ def control_loop(
|
|||
action = {"action": action}
|
||||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action}
|
||||
frame = {**observation, **action, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
|
|
|
@ -263,8 +263,8 @@ def record(
|
|||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
robot=robot,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
episode_time_s=cfg.episode_time_s,
|
||||
display_cameras=cfg.display_cameras,
|
||||
|
@ -272,6 +272,7 @@ def record(
|
|||
device=cfg.device,
|
||||
use_amp=cfg.use_amp,
|
||||
fps=cfg.fps,
|
||||
single_task=cfg.single_task,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
|
@ -291,7 +292,7 @@ def record(
|
|||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode(cfg.single_task)
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
if events["stop_recording"]:
|
||||
|
|
|
@ -93,6 +93,24 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
|
|||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
def test_add_frame_no_task(tmp_path):
|
||||
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
|
||||
with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."):
|
||||
dataset.add_frame({"1d": torch.randn(1)})
|
||||
|
||||
|
||||
def test_add_frame(tmp_path):
|
||||
features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate(run_compute_stats=False)
|
||||
assert len(dataset) == 1
|
||||
assert dataset[0]["task"] == "dummy"
|
||||
assert dataset[0]["task_index"] == 0
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
|
|
Loading…
Reference in New Issue