215 lines
8.0 KiB
Python
215 lines
8.0 KiB
Python
from pathlib import Path
|
|
|
|
import einops
|
|
import numpy as np
|
|
import torch
|
|
import tqdm
|
|
|
|
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
|
|
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
|
|
|
# Coverage ratio above which a frame is flagged as successful (as defined in the env).
SUCCESS_THRESHOLD = 0.95  # 95% coverage

# Archive containing the raw PushT demonstrations.
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
# Location of the zarr replay buffer, relative to the directory the archive is extracted into.
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
|
|
|
|
|
class PushtDataset(torch.utils.data.Dataset):
|
|
"""
|
|
|
|
Arguments
|
|
----------
|
|
delta_timestamps : dict[list[float]] | None, optional
|
|
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
|
|
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
|
|
"""
|
|
|
|
available_datasets = ["pusht"]
|
|
fps = 10
|
|
image_keys = ["observation.image"]
|
|
|
|
def __init__(
|
|
self,
|
|
dataset_id: str,
|
|
version: str | None = "v1.2",
|
|
root: Path | None = None,
|
|
transform: callable = None,
|
|
delta_timestamps: dict[list[float]] | None = None,
|
|
):
|
|
super().__init__()
|
|
self.dataset_id = dataset_id
|
|
self.version = version
|
|
self.root = root
|
|
self.transform = transform
|
|
self.delta_timestamps = delta_timestamps
|
|
|
|
self.data_dir = self.root / f"{self.dataset_id}"
|
|
if (self.data_dir / "data_dict.pth").exists() and (
|
|
self.data_dir / "data_ids_per_episode.pth"
|
|
).exists():
|
|
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
|
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
|
else:
|
|
self._download_and_preproc_obsolete()
|
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
|
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
|
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
|
|
|
@property
|
|
def num_samples(self) -> int:
|
|
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
|
|
|
|
@property
|
|
def num_episodes(self) -> int:
|
|
return len(self.data_ids_per_episode)
|
|
|
|
def __len__(self):
|
|
return self.num_samples
|
|
|
|
def __getitem__(self, idx):
|
|
item = {}
|
|
|
|
# get episode id and timestamp of the sampled frame
|
|
current_ts = self.data_dict["timestamp"][idx].item()
|
|
episode = self.data_dict["episode"][idx].item()
|
|
|
|
for key in self.data_dict:
|
|
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
|
data, is_pad = load_data_with_delta_timestamps(
|
|
self.data_dict,
|
|
self.data_ids_per_episode,
|
|
self.delta_timestamps,
|
|
key,
|
|
current_ts,
|
|
episode,
|
|
)
|
|
item[key] = data
|
|
item[f"{key}_is_pad"] = is_pad
|
|
else:
|
|
item[key] = self.data_dict[key][idx]
|
|
|
|
if self.transform is not None:
|
|
item = self.transform(item)
|
|
|
|
return item
|
|
|
|
def _download_and_preproc_obsolete(self):
|
|
try:
|
|
import pymunk
|
|
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
|
except ModuleNotFoundError as e:
|
|
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
|
raise e
|
|
|
|
assert self.root is not None
|
|
raw_dir = self.root / f"{self.dataset_id}_raw"
|
|
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
|
if not zarr_path.is_dir():
|
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
|
download_and_extract_zip(PUSHT_URL, raw_dir)
|
|
|
|
# load
|
|
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
|
zarr_path
|
|
) # , keys=['img', 'state', 'action'])
|
|
|
|
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
|
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
|
total_frames = dataset_dict["action"].shape[0]
|
|
# to create test artifact
|
|
# num_episodes = 1
|
|
# total_frames = 50
|
|
assert len(
|
|
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
|
), "Some data type dont have the same number of total frames."
|
|
|
|
# TODO: verify that goal pose is expected to be fixed
|
|
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
|
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
|
|
|
imgs = torch.from_numpy(dataset_dict["img"])
|
|
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
|
states = torch.from_numpy(dataset_dict["state"])
|
|
actions = torch.from_numpy(dataset_dict["action"])
|
|
|
|
self.data_ids_per_episode = {}
|
|
ep_dicts = []
|
|
|
|
idx0 = 0
|
|
for episode_id in tqdm.tqdm(range(num_episodes)):
|
|
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
|
|
|
num_frames = idx1 - idx0
|
|
|
|
assert (episode_ids[idx0:idx1] == episode_id).all()
|
|
|
|
image = imgs[idx0:idx1]
|
|
assert image.min() >= 0.0
|
|
assert image.max() <= 255.0
|
|
image = image.type(torch.uint8)
|
|
|
|
state = states[idx0:idx1]
|
|
agent_pos = state[:, :2]
|
|
block_pos = state[:, 2:4]
|
|
block_angle = state[:, 4]
|
|
|
|
reward = torch.zeros(num_frames)
|
|
success = torch.zeros(num_frames, dtype=torch.bool)
|
|
done = torch.zeros(num_frames, dtype=torch.bool)
|
|
for i in range(num_frames):
|
|
space = pymunk.Space()
|
|
space.gravity = 0, 0
|
|
space.damping = 0
|
|
|
|
# Add walls.
|
|
walls = [
|
|
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
|
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
|
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
|
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
|
]
|
|
space.add(*walls)
|
|
|
|
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
|
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
|
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
|
intersection_area = goal_geom.intersection(block_geom).area
|
|
goal_area = goal_geom.area
|
|
coverage = intersection_area / goal_area
|
|
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
|
|
success[i] = coverage > SUCCESS_THRESHOLD
|
|
|
|
# last step of demonstration is considered done
|
|
done[-1] = True
|
|
|
|
ep_dict = {
|
|
"observation.image": image,
|
|
"observation.state": agent_pos,
|
|
"action": actions[idx0:idx1],
|
|
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
|
"frame_id": torch.arange(0, num_frames, 1),
|
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
|
# "next.observation.image": image[1:],
|
|
# "next.observation.state": agent_pos[1:],
|
|
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
|
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
|
"next.done": torch.cat([done[1:], done[[-1]]]),
|
|
"next.success": torch.cat([success[1:], success[[-1]]]),
|
|
}
|
|
ep_dicts.append(ep_dict)
|
|
|
|
assert isinstance(episode_id, int)
|
|
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
|
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
|
|
|
idx0 = idx1
|
|
|
|
self.data_dict = {}
|
|
|
|
keys = ep_dicts[0].keys()
|
|
for key in keys:
|
|
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
|
|
|
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|