diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py index 742d1346..7313f9f1 100644 --- a/examples/port_datasets/pusht_zarr.py +++ b/examples/port_datasets/pusht_zarr.py @@ -7,65 +7,63 @@ import torch from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw +PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface." +PUSHT_FEATURES = { + "observation.state": { + "dtype": "float32", + "shape": (2,), + "names": { + "axes": ["x", "y"], + }, + }, + "action": { + "dtype": "float32", + "shape": (2,), + "names": { + "axes": ["x", "y"], + }, + }, + "next.reward": { + "dtype": "float32", + "shape": (1,), + "names": None, + }, + "next.success": { + "dtype": "bool", + "shape": (1,), + "names": None, + }, + "observation.environment_state": { + "dtype": "float32", + "shape": (16,), + "names": [ + "keypoints", + ], + }, + "observation.image": { + "dtype": None, + "shape": (3, 96, 96), + "names": [ + "channel", + "height", + "width", + ], + }, +} -def create_empty_dataset(repo_id, mode): - features = { - "observation.state": { - "dtype": "float32", - "shape": (2,), - "names": [ - ["x", "y"], - ], - }, - "action": { - "dtype": "float32", - "shape": (2,), - "names": [ - ["x", "y"], - ], - }, - "next.reward": { - "dtype": "float32", - "shape": (1,), - "names": None, - }, - "next.success": { - "dtype": "bool", - "shape": (1,), - "names": None, - }, - } +def build_features(mode: str) -> dict: + features = PUSHT_FEATURES if mode == "keypoints": - features["observation.environment_state"] = { - "dtype": "float32", - "shape": (16,), - "names": [ - "keypoints", - ], - } + features.pop("observation.image") else: - features["observation.image"] = { - "dtype": mode, - "shape": (3, 96, 96), - "names": [ - "channel", - "height", - "width", - ], - } + features.pop("observation.environment_state") + features["observation.image"]["dtype"] = mode - dataset = LeRobotDataset.create( - repo_id=repo_id, - fps=10, - robot_type="2d pointer", - features=features, - image_writer_threads=4, - ) - return dataset + return features -def load_raw_dataset(zarr_path, load_images=True): +def load_raw_dataset(zarr_path: Path, load_images: bool = True): try: from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import ( ReplayBuffer as DiffusionPolicyReplayBuffer, @@ -75,28 +73,10 @@ def load_raw_dataset(zarr_path, load_images=True): raise e zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) - - env_state = zarr_data["state"][:] - agent_pos = env_state[:, :2] - block_pos = env_state[:, 2:4] - block_angle = env_state[:, 4] - - action = zarr_data["action"][:] - - image = None - if load_images: - # b h w c - image = zarr_data["img"] - - episode_data_index = { - "from": np.array([0] + zarr_data.meta["episode_ends"][:-1].tolist()), - "to": zarr_data.meta["episode_ends"], - } - - return image, agent_pos, block_pos, block_angle, action, episode_data_index + return zarr_data -def calculate_coverage(block_pos, block_angle): +def calculate_coverage(zarr_data): try: import pymunk from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely @@ -104,6 +84,9 @@ def calculate_coverage(block_pos, block_angle): print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`") raise e + block_pos = zarr_data["state"][:, 2:4] + block_angle = zarr_data["state"][:, 4] + num_frames = len(block_pos) coverage = np.zeros((num_frames,)) @@ -139,26 +122,61 @@ def calculate_coverage(block_pos, block_angle): return coverage, keypoints -def calculate_success(coverage, success_threshold): +def calculate_success(coverage: float, success_threshold: float): return coverage > success_threshold -def calculate_reward(coverage, success_threshold): +def calculate_reward(coverage: float, success_threshold: float): return np.clip(coverage / success_threshold, 0, 1) -def populate_dataset(dataset, episode_data_index, episodes, image, state, env_state, action, reward, success): - if episodes is None: - episodes = range(len(episode_data_index["from"])) +def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True): + if mode not in ["video", "image", "keypoints"]: + raise ValueError(mode) + if (LEROBOT_HOME / repo_id).exists(): + shutil.rmtree(LEROBOT_HOME / repo_id) + + if not raw_dir.exists(): + download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw") + + zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr") + + env_state = zarr_data["state"][:] + agent_pos = env_state[:, :2] + + action = zarr_data["action"][:] + image = zarr_data["img"] # (b, h, w, c) + + episode_data_index = { + "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])), + "to": zarr_data.meta["episode_ends"], + } + + # Calculate success and reward based on the overlapping area + # of the T-object and the T-area. + coverage, keypoints = calculate_coverage(zarr_data) + success = calculate_success(coverage, success_threshold=0.95) + reward = calculate_reward(coverage, success_threshold=0.95) + + features = build_features(mode) + dataset = LeRobotDataset.create( + repo_id=repo_id, + fps=10, + robot_type="2d pointer", + features=features, + image_writer_threads=4, + ) + episodes = range(len(episode_data_index["from"])) for ep_idx in episodes: from_idx = episode_data_index["from"][ep_idx] to_idx = episode_data_index["to"][ep_idx] num_frames = to_idx - from_idx for frame_idx in range(num_frames): - i = from_idx + frame_idx + # frame = extract_frame_from_zarr(zarr_data, frame_idx) + i = from_idx + frame_idx frame = { "action": torch.from_numpy(action[i]), # Shift reward and success by +1 until the last item of the episode @@ -166,54 +184,17 @@ def populate_dataset(dataset, episode_data_index, episodes, image, state, env_st "next.success": success[i + (frame_idx < num_frames - 1)], } - frame["observation.state"] = torch.from_numpy(state[i]) + frame["observation.state"] = torch.from_numpy(agent_pos[i]) - if env_state is not None: - frame["observation.environment_state"] = torch.from_numpy(env_state[i]) - - if image is not None: + if mode == "keypoints": + frame["observation.environment_state"] = torch.from_numpy(keypoints[i]) + else: frame["observation.image"] = torch.from_numpy(image[i]) dataset.add_frame(frame) - dataset.save_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.") + dataset.save_episode(task=PUSHT_TASK) - return dataset - - -def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True): - if mode not in ["video", "image", "keypoints"]: - raise ValueError(mode) - - if (LEROBOT_HOME / repo_id).exists(): - shutil.rmtree(LEROBOT_HOME / repo_id) - - raw_dir = Path(raw_dir) - if not raw_dir.exists(): - download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw") - - image, agent_pos, block_pos, block_angle, action, episode_data_index = load_raw_dataset( - zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr" - ) - - # Calculate success and reward based on the overlapping area - # of the T-object and the T-area. - coverage, keypoints = calculate_coverage(block_pos, block_angle) - success = calculate_success(coverage, success_threshold=0.95) - reward = calculate_reward(coverage, success_threshold=0.95) - - dataset = create_empty_dataset(repo_id, mode) - dataset = populate_dataset( - dataset, - episode_data_index, - episodes, - image=None if mode == "keypoints" else image, - state=agent_pos, - env_state=keypoints if mode == "keypoints" else None, - action=action, - reward=reward, - success=success, - ) dataset.consolidate() if push_to_hub: @@ -224,23 +205,20 @@ if __name__ == "__main__": # To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht) repo_id = "lerobot/pusht" - episodes = None - # Uncomment if you want to try with a subset (episode 0 and 1) - # episodes = [0, 1] - modes = ["video", "image", "keypoints"] # Uncomment if you want to try with a specific mode # modes = ["video"] # modes = ["image"] # modes = ["keypoints"] - for mode in ["video", "image", "keypoints"]: + raw_dir = Path("data/lerobot-raw/pusht_raw") + for mode in modes: if mode in ["image", "keypoints"]: repo_id += f"_{mode}" # download and load raw dataset, create LeRobotDataset, populate it, push to hub - port_pusht("data/lerobot-raw/pusht_raw", repo_id=repo_id, mode=mode, episodes=episodes) + main(raw_dir, repo_id=repo_id, mode=mode) - # Uncomment if you want to loal the local dataset and explore it + # Uncomment if you want to load the local dataset and explore it # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True) # breakpoint() diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 090134e2..78a3aeaa 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -678,7 +678,7 @@ class LeRobotDataset(torch.utils.data.Dataset): "})',\n" ) - def _create_episode_buffer(self, episode_index: int | None = None) -> dict: + def create_episode_buffer(self, episode_index: int | None = None) -> dict: current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index return { "size": 0, @@ -709,7 +709,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # check the dtype and shape matches, etc. if self.episode_buffer is None: - self.episode_buffer = self._create_episode_buffer() + self.episode_buffer = self.create_episode_buffer() frame_index = self.episode_buffer["size"] timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps @@ -795,7 +795,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_buffer[key] = video_paths[key] if not episode_data: # Reset the buffer - self.episode_buffer = self._create_episode_buffer() + self.episode_buffer = self.create_episode_buffer() self.consolidated = False @@ -817,7 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset): shutil.rmtree(img_dir) # Reset the buffer - self.episode_buffer = self._create_episode_buffer() + self.episode_buffer = self.create_episode_buffer() def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: if isinstance(self.image_writer, AsyncImageWriter): @@ -941,7 +941,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.start_image_writer(image_writer_processes, image_writer_threads) # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer - obj.episode_buffer = obj._create_episode_buffer() + obj.episode_buffer = obj.create_episode_buffer() # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It # is used to know when certain operations are need (for instance, computing dataset statistics). In