Improve dataset v2 (#498)

Remi 2024-11-19 12:31:47 +01:00 committed by GitHub
parent acae4b49d2
commit 1f13bda25b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 393 additions and 70 deletions


@@ -0,0 +1,246 @@
import shutil
from pathlib import Path

import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw


def create_empty_dataset(repo_id, mode):
    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (2,),
            "names": [
                ["x", "y"],
            ],
        },
        "action": {
            "dtype": "float32",
            "shape": (2,),
            "names": [
                ["x", "y"],
            ],
        },
        "next.reward": {
            "dtype": "float32",
            "shape": (1,),
            "names": None,
        },
        "next.success": {
            "dtype": "bool",
            "shape": (1,),
            "names": None,
        },
    }

    if mode == "keypoints":
        features["observation.environment_state"] = {
            "dtype": "float32",
            "shape": (16,),
            "names": [
                "keypoints",
            ],
        }
    else:
        features["observation.image"] = {
            "dtype": mode,
            "shape": (3, 96, 96),
            "names": [
                "channel",
                "height",
                "width",
            ],
        }

    dataset = LeRobotDataset.create(
        repo_id=repo_id,
        fps=10,
        robot_type="2d pointer",
        features=features,
        image_writer_threads=4,
    )
    return dataset


def load_raw_dataset(zarr_path, load_images=True):
    try:
        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,
        )
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e

    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

    env_state = zarr_data["state"][:]
    agent_pos = env_state[:, :2]
    block_pos = env_state[:, 2:4]
    block_angle = env_state[:, 4]
    action = zarr_data["action"][:]

    image = None
    if load_images:
        # b h w c
        image = zarr_data["img"]

    episode_data_index = {
        "from": np.array([0] + zarr_data.meta["episode_ends"][:-1].tolist()),
        "to": zarr_data.meta["episode_ends"],
    }

    return image, agent_pos, block_pos, block_angle, action, episode_data_index


def calculate_coverage(block_pos, block_angle):
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e

    num_frames = len(block_pos)

    coverage = np.zeros((num_frames,))
    # 8 keypoints with 2 coords each
    keypoints = np.zeros((num_frames, 16))

    # Set x, y, theta (in radians)
    goal_pos_angle = np.array([256, 256, np.pi / 4])
    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

    for i in range(num_frames):
        space = pymunk.Space()
        space.gravity = 0, 0
        space.damping = 0

        # Add walls.
        walls = [
            PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
            PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
            PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
            PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
        ]
        space.add(*walls)

        block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
        goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
        block_geom = pymunk_to_shapely(block_body, block_body.shapes)
        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage[i] = intersection_area / goal_area
        keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

    return coverage, keypoints


def calculate_success(coverage, success_threshold):
    return coverage > success_threshold


def calculate_reward(coverage, success_threshold):
    return np.clip(coverage / success_threshold, 0, 1)


def populate_dataset(dataset, episode_data_index, episodes, image, state, env_state, action, reward, success):
    if episodes is None:
        episodes = range(len(episode_data_index["from"]))

    for ep_idx in episodes:
        from_idx = episode_data_index["from"][ep_idx]
        to_idx = episode_data_index["to"][ep_idx]
        num_frames = to_idx - from_idx

        for frame_idx in range(num_frames):
            i = from_idx + frame_idx
            frame = {
                "action": torch.from_numpy(action[i]),
                # Shift reward and success by +1 until the last item of the episode
                "next.reward": reward[i + (frame_idx < num_frames - 1)],
                "next.success": success[i + (frame_idx < num_frames - 1)],
            }

            frame["observation.state"] = torch.from_numpy(state[i])

            if env_state is not None:
                frame["observation.environment_state"] = torch.from_numpy(env_state[i])
            if image is not None:
                frame["observation.image"] = torch.from_numpy(image[i])

            dataset.add_frame(frame)

        dataset.save_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")

    return dataset


def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
    if mode not in ["video", "image", "keypoints"]:
        raise ValueError(mode)

    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    raw_dir = Path(raw_dir)
    if not raw_dir.exists():
        download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")

    image, agent_pos, block_pos, block_angle, action, episode_data_index = load_raw_dataset(
        zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr"
    )

    # Calculate success and reward based on the overlapping area
    # of the T-object and the T-area.
    coverage, keypoints = calculate_coverage(block_pos, block_angle)
    success = calculate_success(coverage, success_threshold=0.95)
    reward = calculate_reward(coverage, success_threshold=0.95)

    dataset = create_empty_dataset(repo_id, mode)
    dataset = populate_dataset(
        dataset,
        episode_data_index,
        episodes,
        image=None if mode == "keypoints" else image,
        state=agent_pos,
        env_state=keypoints if mode == "keypoints" else None,
        action=action,
        reward=reward,
        success=success,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()
if __name__ == "__main__":
    # To try this script, modify the repo id with your own HuggingFace user (e.g. cadene/pusht)
    repo_id = "lerobot/pusht"

    episodes = None
    # Uncomment if you want to try with a subset (episodes 0 and 1)
    # episodes = [0, 1]

    modes = ["video", "image", "keypoints"]
    # Uncomment if you want to try with a specific mode
    # modes = ["video"]
    # modes = ["image"]
    # modes = ["keypoints"]

    for mode in modes:
        # Give each mode its own repo id (e.g. lerobot/pusht_image, lerobot/pusht_keypoints)
        # instead of appending suffixes to the same variable across iterations.
        mode_repo_id = repo_id if mode == "video" else f"{repo_id}_{mode}"

        # download and load raw dataset, create LeRobotDataset, populate it, push to hub
        port_pusht("data/lerobot-raw/pusht_raw", repo_id=mode_repo_id, mode=mode, episodes=episodes)

        # Uncomment if you want to load the local dataset and explore it
        # dataset = LeRobotDataset(repo_id=mode_repo_id, local_files_only=True)
        # breakpoint()
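
Once the port has run, the resulting dataset behaves like any other LeRobotDataset. A minimal usage sketch (not part of this commit; it assumes the "video" variant was written locally under lerobot/pusht):

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load the locally ported dataset and iterate over it with a standard PyTorch DataLoader.
dataset = LeRobotDataset(repo_id="lerobot/pusht", local_files_only=True)
print(f"{dataset.num_episodes=} {len(dataset)=}")

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(dataloader))
print(batch["observation.state"].shape)  # torch.Size([32, 2])
print(batch["action"].shape)  # torch.Size([32, 2])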


@@ -280,6 +280,8 @@ class LeRobotDatasetMetadata:
        obj.repo_id = repo_id
        obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
        obj.root.mkdir(parents=True, exist_ok=False)

        if robot is not None:
            features = get_features_from_robot(robot, use_videos)
            robot_type = robot.robot_type
@@ -293,6 +295,7 @@ class LeRobotDatasetMetadata:
                "Dataset features must either come from a Robot or explicitly passed upon creation."
            )
        else:
+            # TODO(aliberts, rcadene): implement sanity check for features
            features = {**features, **DEFAULT_FEATURES}

        obj.tasks, obj.stats, obj.episodes = {}, {}, []
@@ -424,11 +427,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.video_backend = video_backend if video_backend is not None else "pyav"
        self.delta_indices = None
        self.local_files_only = local_files_only
-        self.consolidated = True

        # Unused attributes
        self.image_writer = None
-        self.episode_buffer = {}
+        self.episode_buffer = None

        self.root.mkdir(exist_ok=True, parents=True)
@@ -451,12 +453,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
            self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

+        # Available stats implies all videos have been encoded and dataset is iterable
+        self.consolidated = self.meta.stats is not None

    def push_to_hub(
        self,
        tags: list | None = None,
        text: str | None = None,
-        license: str | None = "mit",
+        license: str | None = "apache-2.0",
        push_videos: bool = True,
+        private: bool = False,
    ) -> None:
        if not self.consolidated:
            raise RuntimeError(
@@ -468,7 +474,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
        if not push_videos:
            ignore_patterns.append("videos/")

-        create_repo(self.repo_id, repo_type="dataset", exist_ok=True)
+        create_repo(
+            repo_id=self.repo_id,
+            private=private,
+            repo_type="dataset",
+            exist_ok=True,
+        )

        upload_folder(
            repo_id=self.repo_id,
            folder_path=self.root,
@@ -658,7 +670,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
        return {
            "size": 0,
-            **{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
+            **{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
        }

    def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
@@ -681,8 +693,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
        temporary directory nothing is written to disk. To save those frames, the 'save_episode()' method
        then needs to be called.
        """
+        # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
+        # check the dtype and shape matches, etc.
+
+        if self.episode_buffer is None:
+            self.episode_buffer = self._create_episode_buffer()

        frame_index = self.episode_buffer["size"]
-        timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
+        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
        self.episode_buffer["frame_index"].append(frame_index)
        self.episode_buffer["timestamp"].append(timestamp)
@@ -723,6 +741,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
            # TODO(aliberts): Add option to use existing episode_index
            raise NotImplementedError()

+        if episode_length == 0:
+            raise ValueError(
+                "You must add one or several frames with `add_frame` before calling `add_episode`."
+            )

        task_index = self.meta.get_task_index(task)

        if not set(episode_buffer.keys()) == set(self.features):
@@ -781,7 +804,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # Reset the buffer
        self.episode_buffer = self._create_episode_buffer()

-    def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
+    def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
        if isinstance(self.image_writer, AsyncImageWriter):
            logging.warning(
                "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."


@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+import textwrap
import warnings
from itertools import accumulate
from pathlib import Path
@@ -139,6 +140,8 @@ def load_info(local_dir: Path) -> dict:
def load_stats(local_dir: Path) -> dict:
+    if not (local_dir / STATS_PATH).exists():
+        return None

    stats = load_json(local_dir / STATS_PATH)
    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)
@@ -186,17 +189,37 @@ def _get_major_minor(version: str) -> tuple[int]:
    return int(split[0]), int(split[1])


class BackwardCompatibilityError(Exception):
    def __init__(self, repo_id, version):
        message = textwrap.dedent(f"""
            BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
            We introduced a new format since v2.0 which is not backward compatible with v1.x.
            Please, use our conversion script. Modify the following command with your own task description:

            ```
            python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
                --repo-id {repo_id} \\
                --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
            ```

            A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
            "Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
            "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
            "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...

            If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
            or open an issue on GitHub.
        """)
        super().__init__(message)


def check_version_compatibility(
    repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
    current_major, _ = _get_major_minor(current_version)
    major_to_check, _ = _get_major_minor(version_to_check)
    if major_to_check < current_major and enforce_breaking_major:
-        raise ValueError(
-            f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
-            format with v2.0 that is not backward compatible. Please use our conversion script
-            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
-        )
+        raise BackwardCompatibilityError(repo_id, version_to_check)
    elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
        warnings.warn(
            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
@@ -207,18 +230,16 @@ def check_version_compatibility(
        )


-def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
-    num_version = float(version.strip("v"))
-    if num_version < 2 and enforce_v2:
-        raise ValueError(
-            f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
-            format with v2.0 that is not backward compatible. Please use our conversion script
-            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
-        )
+def get_hub_safe_version(repo_id: str, version: str) -> str:
    api = HfApi()
    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
    branches = [b.name for b in dataset_info.branches]
    if version not in branches:
+        num_version = float(version.strip("v"))
+        hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
+        if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
+            raise BackwardCompatibilityError(repo_id, version)
+
        warnings.warn(
            f"""You are trying to load a dataset from {repo_id} created with a previous version of the
            codebase. The following versions are available: {branches}.
@@ -461,6 +482,7 @@ def create_lerobot_dataset_card(
        }
    ]
    card.data.task_categories = ["robotics"]
-    card.data.license = license
    card.data.tags = ["LeRobot"]
+    if license:
+        card.data.license = license
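
A hedged example of the new failure mode (repo id and versions are illustrative; it assumes these helpers live in lerobot.common.datasets.utils, as the diffed module suggests): requesting a v1.x dataset with the v2 codebase now raises BackwardCompatibilityError, whose message contains the exact conversion command, instead of a bare ValueError:

from lerobot.common.datasets.utils import BackwardCompatibilityError, check_version_compatibility

try:
    # A dataset still in v1.6 format checked against a v2.0 codebase triggers the new error.
    check_version_compatibility("lerobot/pusht", version_to_check="v1.6", current_version="v2.0")
except BackwardCompatibilityError as e:
    print(e)  # includes the convert_dataset_v1_to_v2.py command to run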


@@ -441,7 +441,7 @@ def convert_dataset(
    arxiv: str | None = None,
    test_branch: str | None = None,
):
-    v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
+    v1 = get_hub_safe_version(repo_id, V16)
    v1x_dir = local_dir / V16 / repo_id
    v20_dir = local_dir / V20 / repo_id

    v1x_dir.mkdir(parents=True, exist_ok=True)


@@ -17,6 +17,7 @@ from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
@@ -330,3 +331,21 @@ def sanity_check_dataset_name(repo_id, policy):
        raise ValueError(
            f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
        )


def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
    fields = [
        ("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
        ("fps", dataset.meta.info["fps"], fps),
        ("features", dataset.features, get_features_from_robot(robot, use_videos)),
    ]

    mismatches = []
    for field, dataset_value, present_value in fields:
        if dataset_value != present_value:
            mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")

    if mismatches:
        raise ValueError(
            "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
        )
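
Roughly, this check makes a resumed recording fail fast when the current robot or settings no longer match the stored dataset. An illustrative sketch, assuming `dataset` and `robot` come from an existing recording session and that the stored dataset was recorded at 30 fps:

from lerobot.common.robot_devices.control_utils import sanity_check_dataset_robot_compatibility

try:
    # The current session asks for 25 fps while the stored metadata says 30 fps.
    sanity_check_dataset_robot_compatibility(dataset, robot, fps=25, use_videos=True)
except ValueError as e:
    print(e)
    # Dataset metadata compatibility check failed with mismatches:
    # fps: expected 25, got 30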


@@ -115,6 +115,7 @@ from lerobot.common.robot_devices.control_utils import (
    record_episode,
    reset_environment,
    sanity_check_dataset_name,
+    sanity_check_dataset_robot_compatibility,
    stop_recording,
    warmup_record,
)
@@ -207,6 +208,9 @@ def record(
    num_image_writer_threads_per_camera: int = 4,
    display_cameras: bool = True,
    play_sounds: bool = True,
+    resume: bool = False,
+    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
+    local_files_only: bool = False,
) -> LeRobotDataset:
    # TODO(rcadene): Add option to record logs
    listener = None
@@ -232,17 +236,29 @@ def record(
            f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
        )

-    # Create empty dataset or load existing saved episodes
-    sanity_check_dataset_name(repo_id, policy)
-    dataset = LeRobotDataset.create(
-        repo_id,
-        fps,
-        root=root,
-        robot=robot,
-        use_videos=video,
-        image_writer_processes=num_image_writer_processes,
-        image_writer_threads=num_image_writer_threads_per_camera,
-    )
+    if resume:
+        dataset = LeRobotDataset(
+            repo_id,
+            root=root,
+            local_files_only=local_files_only,
+        )
+        dataset.start_image_writer(
+            num_processes=num_image_writer_processes,
+            num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
+        )
+        sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
+    else:
+        # Create empty dataset or load existing saved episodes
+        sanity_check_dataset_name(repo_id, policy)
+        dataset = LeRobotDataset.create(
+            repo_id,
+            fps,
+            root=root,
+            robot=robot,
+            use_videos=video,
+            image_writer_processes=num_image_writer_processes,
+            image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
+        )

    if not robot.is_connected:
        robot.connect()
@@ -270,8 +286,7 @@ def record(
        # if multi_task:
        #     task = input("Enter your task description: ")

-        episode_index = dataset.episode_buffer["episode_index"]
-        log_say(f"Recording episode {episode_index}", play_sounds)
+        log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
        record_episode(
            dataset=dataset,
            robot=robot,
@@ -289,7 +304,7 @@ def record(
        # TODO(rcadene): add an option to enable teleoperation during reset
        # Skip reset for the last episode to be recorded
        if not events["stop_recording"] and (
-            (episode_index < num_episodes - 1) or events["rerecord_episode"]
+            (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", play_sounds)
            reset_environment(robot, events, reset_time_s)
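
In practice the resume path looks roughly like this when calling record() from Python (a sketch; `robot`, `root`, `repo_id` and `single_task` are assumed from your own setup, and the argument order mirrors the test further below):

from lerobot.scripts.control_robot import record

# First session: creates a new local dataset and records one episode.
dataset = record(robot, root, repo_id, single_task, fps=30, num_episodes=1, push_to_hub=False)

# Later session: reopen the same local dataset and keep appending episodes
# instead of failing because the directory already exists.
dataset = record(
    robot,
    root,
    repo_id,
    single_task,
    fps=30,
    num_episodes=1,
    push_to_hub=False,
    resume=True,
    local_files_only=True,
)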


@@ -117,10 +117,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
def push_dataset_card_to_hub(
-    repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
+    repo_id: str,
+    revision: str | None,
+    tags: list | None = None,
+    text: str | None = None,
+    license: str = "apache-2.0",
):
    """Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
-    card = create_lerobot_dataset_card(tags=tags, text=text)
+    card = create_lerobot_dataset_card(tags=tags, text=text, license=license)
    card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
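
On the v1.x tooling side, the card helper now takes the license explicitly. A small sketch (repo id and revision are illustrative):

# Push a dataset card for a v1.x dataset; the license now defaults to "apache-2.0".
push_dataset_card_to_hub("lerobot/pusht", revision="v1.6", tags=["LeRobot"], license="apache-2.0")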


@@ -325,7 +325,7 @@ def lerobot_dataset_metadata_factory(
                "lerobot.common.datasets.lerobot_dataset.snapshot_download"
            ) as mock_snapshot_download_patch,
        ):
-            mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
+            mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
            mock_snapshot_download_patch.side_effect = mock_snapshot_download

            return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)


@@ -275,22 +275,25 @@ def test_resume_record(tmpdir, request, robot_type, mock):
    root = Path(tmpdir) / "data" / repo_id
    single_task = "Do something."

-    dataset = record(
-        robot,
-        root,
-        repo_id,
-        single_task,
-        fps=1,
-        warmup_time_s=0,
-        episode_time_s=1,
-        num_episodes=1,
-        push_to_hub=False,
-        video=False,
-        display_cameras=False,
-        play_sounds=False,
-        run_compute_stats=False,
-    )
-    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
+    record_kwargs = {
+        "robot": robot,
+        "root": root,
+        "repo_id": repo_id,
+        "single_task": single_task,
+        "fps": 1,
+        "warmup_time_s": 0,
+        "episode_time_s": 1,
+        "push_to_hub": False,
+        "video": False,
+        "display_cameras": False,
+        "play_sounds": False,
+        "run_compute_stats": False,
+        "local_files_only": True,
+        "num_episodes": 1,
+    }
+    dataset = record(**record_kwargs)
+    assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"

    # init_dataset_return_value = {}
@@ -300,22 +303,13 @@ def test_resume_record(tmpdir, request, robot_type, mock):
    #     return init_dataset_return_value

    # with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
-    dataset = record(
-        robot,
-        root,
-        repo_id,
-        single_task,
-        fps=1,
-        warmup_time_s=0,
-        episode_time_s=1,
-        num_episodes=2,
-        push_to_hub=False,
-        video=False,
-        display_cameras=False,
-        play_sounds=False,
-        run_compute_stats=False,
-    )
-    assert len(dataset) == 2, "`dataset` should contain only 1 frame"
+    with pytest.raises(FileExistsError):
+        # Dataset already exists, but resume=False by default
+        record(**record_kwargs)
+
+    dataset = record(**record_kwargs, resume=True)
+    assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"

    # assert (
    #     init_dataset_return_value["num_episodes"] == 2
    # ), "`init_dataset` should load the previous episode"