Refactor pusht_zarr
This commit is contained in:
parent
3b5af7eb38
commit
6ad84a6561
|
@ -7,65 +7,63 @@ import torch
|
|||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
|
||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||
PUSHT_FEATURES = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {
|
||||
"axes": ["x", "y"],
|
||||
},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {
|
||||
"axes": ["x", "y"],
|
||||
},
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"next.success": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"observation.environment_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (16,),
|
||||
"names": [
|
||||
"keypoints",
|
||||
],
|
||||
},
|
||||
"observation.image": {
|
||||
"dtype": None,
|
||||
"shape": (3, 96, 96),
|
||||
"names": [
|
||||
"channel",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
def create_empty_dataset(repo_id, mode):
|
||||
features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": [
|
||||
["x", "y"],
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": [
|
||||
["x", "y"],
|
||||
],
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"next.success": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
}
|
||||
|
||||
def build_features(mode: str) -> dict:
|
||||
features = PUSHT_FEATURES
|
||||
if mode == "keypoints":
|
||||
features["observation.environment_state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (16,),
|
||||
"names": [
|
||||
"keypoints",
|
||||
],
|
||||
}
|
||||
features.pop("observation.image")
|
||||
else:
|
||||
features["observation.image"] = {
|
||||
"dtype": mode,
|
||||
"shape": (3, 96, 96),
|
||||
"names": [
|
||||
"channel",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
}
|
||||
features.pop("observation.environment_state")
|
||||
features["observation.image"]["dtype"] = mode
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=10,
|
||||
robot_type="2d pointer",
|
||||
features=features,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
return dataset
|
||||
return features
|
||||
|
||||
|
||||
def load_raw_dataset(zarr_path, load_images=True):
|
||||
def load_raw_dataset(zarr_path: Path, load_images: bool = True):
|
||||
try:
|
||||
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
|
@ -75,28 +73,10 @@ def load_raw_dataset(zarr_path, load_images=True):
|
|||
raise e
|
||||
|
||||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
||||
|
||||
env_state = zarr_data["state"][:]
|
||||
agent_pos = env_state[:, :2]
|
||||
block_pos = env_state[:, 2:4]
|
||||
block_angle = env_state[:, 4]
|
||||
|
||||
action = zarr_data["action"][:]
|
||||
|
||||
image = None
|
||||
if load_images:
|
||||
# b h w c
|
||||
image = zarr_data["img"]
|
||||
|
||||
episode_data_index = {
|
||||
"from": np.array([0] + zarr_data.meta["episode_ends"][:-1].tolist()),
|
||||
"to": zarr_data.meta["episode_ends"],
|
||||
}
|
||||
|
||||
return image, agent_pos, block_pos, block_angle, action, episode_data_index
|
||||
return zarr_data
|
||||
|
||||
|
||||
def calculate_coverage(block_pos, block_angle):
|
||||
def calculate_coverage(zarr_data):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
|
@ -104,6 +84,9 @@ def calculate_coverage(block_pos, block_angle):
|
|||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
block_pos = zarr_data["state"][:, 2:4]
|
||||
block_angle = zarr_data["state"][:, 4]
|
||||
|
||||
num_frames = len(block_pos)
|
||||
|
||||
coverage = np.zeros((num_frames,))
|
||||
|
@ -139,26 +122,61 @@ def calculate_coverage(block_pos, block_angle):
|
|||
return coverage, keypoints
|
||||
|
||||
|
||||
def calculate_success(coverage, success_threshold):
|
||||
def calculate_success(coverage: float, success_threshold: float):
|
||||
return coverage > success_threshold
|
||||
|
||||
|
||||
def calculate_reward(coverage, success_threshold):
|
||||
def calculate_reward(coverage: float, success_threshold: float):
|
||||
return np.clip(coverage / success_threshold, 0, 1)
|
||||
|
||||
|
||||
def populate_dataset(dataset, episode_data_index, episodes, image, state, env_state, action, reward, success):
|
||||
if episodes is None:
|
||||
episodes = range(len(episode_data_index["from"]))
|
||||
def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
|
||||
if mode not in ["video", "image", "keypoints"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||
|
||||
zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
|
||||
|
||||
env_state = zarr_data["state"][:]
|
||||
agent_pos = env_state[:, :2]
|
||||
|
||||
action = zarr_data["action"][:]
|
||||
image = zarr_data["img"] # (b, h, w, c)
|
||||
|
||||
episode_data_index = {
|
||||
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
||||
"to": zarr_data.meta["episode_ends"],
|
||||
}
|
||||
|
||||
# Calculate success and reward based on the overlapping area
|
||||
# of the T-object and the T-area.
|
||||
coverage, keypoints = calculate_coverage(zarr_data)
|
||||
success = calculate_success(coverage, success_threshold=0.95)
|
||||
reward = calculate_reward(coverage, success_threshold=0.95)
|
||||
|
||||
features = build_features(mode)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=10,
|
||||
robot_type="2d pointer",
|
||||
features=features,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
episodes = range(len(episode_data_index["from"]))
|
||||
for ep_idx in episodes:
|
||||
from_idx = episode_data_index["from"][ep_idx]
|
||||
to_idx = episode_data_index["to"][ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
|
||||
for frame_idx in range(num_frames):
|
||||
i = from_idx + frame_idx
|
||||
# frame = extract_frame_from_zarr(zarr_data, frame_idx)
|
||||
|
||||
i = from_idx + frame_idx
|
||||
frame = {
|
||||
"action": torch.from_numpy(action[i]),
|
||||
# Shift reward and success by +1 until the last item of the episode
|
||||
|
@ -166,54 +184,17 @@ def populate_dataset(dataset, episode_data_index, episodes, image, state, env_st
|
|||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
||||
}
|
||||
|
||||
frame["observation.state"] = torch.from_numpy(state[i])
|
||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
||||
|
||||
if env_state is not None:
|
||||
frame["observation.environment_state"] = torch.from_numpy(env_state[i])
|
||||
|
||||
if image is not None:
|
||||
if mode == "keypoints":
|
||||
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
|
||||
else:
|
||||
frame["observation.image"] = torch.from_numpy(image[i])
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")
|
||||
dataset.save_episode(task=PUSHT_TASK)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
|
||||
if mode not in ["video", "image", "keypoints"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
raw_dir = Path(raw_dir)
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||
|
||||
image, agent_pos, block_pos, block_angle, action, episode_data_index = load_raw_dataset(
|
||||
zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr"
|
||||
)
|
||||
|
||||
# Calculate success and reward based on the overlapping area
|
||||
# of the T-object and the T-area.
|
||||
coverage, keypoints = calculate_coverage(block_pos, block_angle)
|
||||
success = calculate_success(coverage, success_threshold=0.95)
|
||||
reward = calculate_reward(coverage, success_threshold=0.95)
|
||||
|
||||
dataset = create_empty_dataset(repo_id, mode)
|
||||
dataset = populate_dataset(
|
||||
dataset,
|
||||
episode_data_index,
|
||||
episodes,
|
||||
image=None if mode == "keypoints" else image,
|
||||
state=agent_pos,
|
||||
env_state=keypoints if mode == "keypoints" else None,
|
||||
action=action,
|
||||
reward=reward,
|
||||
success=success,
|
||||
)
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
|
@ -224,23 +205,20 @@ if __name__ == "__main__":
|
|||
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
||||
repo_id = "lerobot/pusht"
|
||||
|
||||
episodes = None
|
||||
# Uncomment if you want to try with a subset (episode 0 and 1)
|
||||
# episodes = [0, 1]
|
||||
|
||||
modes = ["video", "image", "keypoints"]
|
||||
# Uncomment if you want to try with a specific mode
|
||||
# modes = ["video"]
|
||||
# modes = ["image"]
|
||||
# modes = ["keypoints"]
|
||||
|
||||
for mode in ["video", "image", "keypoints"]:
|
||||
raw_dir = Path("data/lerobot-raw/pusht_raw")
|
||||
for mode in modes:
|
||||
if mode in ["image", "keypoints"]:
|
||||
repo_id += f"_{mode}"
|
||||
|
||||
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
||||
port_pusht("data/lerobot-raw/pusht_raw", repo_id=repo_id, mode=mode, episodes=episodes)
|
||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
||||
|
||||
# Uncomment if you want to loal the local dataset and explore it
|
||||
# Uncomment if you want to load the local dataset and explore it
|
||||
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
|
||||
# breakpoint()
|
||||
|
|
|
@ -678,7 +678,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"})',\n"
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
return {
|
||||
"size": 0,
|
||||
|
@ -709,7 +709,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# check the dtype and shape matches, etc.
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
|
@ -795,7 +795,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
episode_buffer[key] = video_paths[key]
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
self.consolidated = False
|
||||
|
||||
|
@ -817,7 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
shutil.rmtree(img_dir)
|
||||
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
||||
if isinstance(self.image_writer, AsyncImageWriter):
|
||||
|
@ -941,7 +941,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj._create_episode_buffer()
|
||||
obj.episode_buffer = obj.create_episode_buffer()
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
||||
# is used to know when certain operations are need (for instance, computing dataset statistics). In
|
||||
|
|
Loading…
Reference in New Issue