WIP
WIP WIP train.py works, loss going down WIP eval.py Fix WIP (eval running, TODO: verify results reproduced) Eval works! (testing reproducibility) WIP pretrained model pusht reproduces same results as torchrl pretrained model pusht reproduces same results as torchrl Remove AbstractPolicy, Move all queues in select_action WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
parent
920e0d118b
commit
1cdfbc8b52
|
@ -1,26 +1,13 @@
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import gdown
|
import gdown
|
||||||
import h5py
|
import h5py
|
||||||
import torch
|
import torch
|
||||||
import torchrl
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from tensordict import TensorDict
|
|
||||||
from torchrl.data.replay_buffers.samplers import Sampler
|
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
|
||||||
from torchrl.data.replay_buffers.writers import Writer
|
|
||||||
|
|
||||||
from lerobot.common.datasets.abstract import AbstractDataset
|
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||||
|
|
||||||
DATASET_IDS = [
|
|
||||||
"aloha_sim_insertion_human",
|
|
||||||
"aloha_sim_insertion_scripted",
|
|
||||||
"aloha_sim_transfer_cube_human",
|
|
||||||
"aloha_sim_transfer_cube_scripted",
|
|
||||||
]
|
|
||||||
|
|
||||||
FOLDER_URLS = {
|
FOLDER_URLS = {
|
||||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||||
|
@ -66,7 +53,6 @@ CAMERAS = {
|
||||||
|
|
||||||
|
|
||||||
def download(data_dir, dataset_id):
|
def download(data_dir, dataset_id):
|
||||||
assert dataset_id in DATASET_IDS
|
|
||||||
assert dataset_id in FOLDER_URLS
|
assert dataset_id in FOLDER_URLS
|
||||||
assert dataset_id in EP48_URLS
|
assert dataset_id in EP48_URLS
|
||||||
assert dataset_id in EP49_URLS
|
assert dataset_id in EP49_URLS
|
||||||
|
@ -80,51 +66,78 @@ def download(data_dir, dataset_id):
|
||||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||||
|
|
||||||
|
|
||||||
class AlohaDataset(AbstractDataset):
|
class AlohaDataset(torch.utils.data.Dataset):
|
||||||
available_datasets = DATASET_IDS
|
available_datasets = [
|
||||||
|
"aloha_sim_insertion_human",
|
||||||
|
"aloha_sim_insertion_scripted",
|
||||||
|
"aloha_sim_transfer_cube_human",
|
||||||
|
"aloha_sim_transfer_cube_scripted",
|
||||||
|
]
|
||||||
|
fps = 50
|
||||||
|
image_keys = ["observation.images.top"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
version: str | None = "v1.2",
|
version: str | None = "v1.2",
|
||||||
batch_size: int | None = None,
|
|
||||||
*,
|
|
||||||
shuffle: bool = True,
|
|
||||||
root: Path | None = None,
|
root: Path | None = None,
|
||||||
pin_memory: bool = False,
|
transform: callable = None,
|
||||||
prefetch: int = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
sampler: Sampler | None = None,
|
|
||||||
collate_fn: Callable | None = None,
|
|
||||||
writer: Writer | None = None,
|
|
||||||
transform: "torchrl.envs.Transform" = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__()
|
||||||
dataset_id,
|
self.dataset_id = dataset_id
|
||||||
version,
|
self.version = version
|
||||||
batch_size,
|
self.root = root
|
||||||
shuffle=shuffle,
|
self.transform = transform
|
||||||
root=root,
|
self.delta_timestamps = delta_timestamps
|
||||||
pin_memory=pin_memory,
|
|
||||||
prefetch=prefetch,
|
data_dir = self.root / f"{self.dataset_id}"
|
||||||
sampler=sampler,
|
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||||
collate_fn=collate_fn,
|
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||||
writer=writer,
|
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||||
transform=transform,
|
else:
|
||||||
|
self._download_and_preproc_obsolete()
|
||||||
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||||
|
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_samples(self) -> int:
|
||||||
|
return len(self.data_dict["index"])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_episodes(self) -> int:
|
||||||
|
return len(self.data_ids_per_episode)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
item = {}
|
||||||
|
|
||||||
|
# get episode id and timestamp of the sampled frame
|
||||||
|
current_ts = self.data_dict["timestamp"][idx].item()
|
||||||
|
episode = self.data_dict["episode"][idx].item()
|
||||||
|
|
||||||
|
for key in self.data_dict:
|
||||||
|
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||||
|
data, is_pad = load_data_with_delta_timestamps(
|
||||||
|
self.data_dict,
|
||||||
|
self.data_ids_per_episode,
|
||||||
|
self.delta_timestamps,
|
||||||
|
key,
|
||||||
|
current_ts,
|
||||||
|
episode,
|
||||||
)
|
)
|
||||||
|
item[key] = data
|
||||||
|
item[f"{key}_is_pad"] = is_pad
|
||||||
|
else:
|
||||||
|
item[key] = self.data_dict[key][idx]
|
||||||
|
|
||||||
@property
|
if self.transform is not None:
|
||||||
def stats_patterns(self) -> dict:
|
item = self.transform(item)
|
||||||
d = {
|
|
||||||
("observation", "state"): "b c -> c",
|
|
||||||
("action",): "b c -> c",
|
|
||||||
}
|
|
||||||
for cam in CAMERAS[self.dataset_id]:
|
|
||||||
d[("observation", "image", cam)] = "b c h w -> c 1 1"
|
|
||||||
return d
|
|
||||||
|
|
||||||
@property
|
return item
|
||||||
def image_keys(self) -> list:
|
|
||||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
|
||||||
|
|
||||||
def _download_and_preproc_obsolete(self):
|
def _download_and_preproc_obsolete(self):
|
||||||
assert self.root is not None
|
assert self.root is not None
|
||||||
|
@ -132,54 +145,55 @@ class AlohaDataset(AbstractDataset):
|
||||||
if not raw_dir.is_dir():
|
if not raw_dir.is_dir():
|
||||||
download(raw_dir, self.dataset_id)
|
download(raw_dir, self.dataset_id)
|
||||||
|
|
||||||
total_num_frames = 0
|
total_frames = 0
|
||||||
logging.info("Compute total number of frames to initialize offline buffer")
|
logging.info("Compute total number of frames to initialize offline buffer")
|
||||||
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||||
with h5py.File(ep_path, "r") as ep:
|
with h5py.File(ep_path, "r") as ep:
|
||||||
total_num_frames += ep["/action"].shape[0] - 1
|
total_frames += ep["/action"].shape[0] - 1
|
||||||
logging.info(f"{total_num_frames=}")
|
logging.info(f"{total_frames=}")
|
||||||
|
|
||||||
|
self.data_ids_per_episode = {}
|
||||||
|
ep_dicts = []
|
||||||
|
|
||||||
logging.info("Initialize and feed offline buffer")
|
logging.info("Initialize and feed offline buffer")
|
||||||
idxtd = 0
|
|
||||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||||
with h5py.File(ep_path, "r") as ep:
|
with h5py.File(ep_path, "r") as ep:
|
||||||
ep_num_frames = ep["/action"].shape[0]
|
num_frames = ep["/action"].shape[0]
|
||||||
|
|
||||||
# last step of demonstration is considered done
|
# last step of demonstration is considered done
|
||||||
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
|
done = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||||
done[-1] = True
|
done[-1] = True
|
||||||
|
|
||||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||||
action = torch.from_numpy(ep["/action"][:])
|
action = torch.from_numpy(ep["/action"][:])
|
||||||
|
|
||||||
ep_td = TensorDict(
|
ep_dict = {
|
||||||
{
|
"observation.state": state,
|
||||||
("observation", "state"): state[:-1],
|
"action": action,
|
||||||
"action": action[:-1],
|
"episode": torch.tensor([ep_id] * num_frames),
|
||||||
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
"frame_id": torch.arange(0, num_frames, 1),
|
||||||
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||||
("next", "observation", "state"): state[1:],
|
# "next.observation.state": state,
|
||||||
# TODO: compute reward and success
|
# TODO(rcadene): compute reward and success
|
||||||
# ("next", "reward"): reward[1:],
|
# "next.reward": reward[1:],
|
||||||
("next", "done"): done[1:],
|
"next.done": done[1:],
|
||||||
# ("next", "success"): success[1:],
|
# "next.success": success[1:],
|
||||||
},
|
}
|
||||||
batch_size=ep_num_frames - 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
for cam in CAMERAS[self.dataset_id]:
|
for cam in CAMERAS[self.dataset_id]:
|
||||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||||
ep_td["observation", "image", cam] = image[:-1]
|
ep_dict[f"observation.images.{cam}"] = image[:-1]
|
||||||
ep_td["next", "observation", "image", cam] = image[1:]
|
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
|
||||||
|
|
||||||
if ep_id == 0:
|
ep_dicts.append(ep_dict)
|
||||||
# hack to initialize tensordict data structure to store episodes
|
|
||||||
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
|
|
||||||
|
|
||||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
self.data_dict = {}
|
||||||
idxtd = idxtd + len(ep_td)
|
|
||||||
|
|
||||||
return TensorStorage(td_data.lock_())
|
keys = ep_dicts[0].keys()
|
||||||
|
for key in keys:
|
||||||
|
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||||
|
|
||||||
|
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
from torchvision.transforms import v2
|
||||||
|
|
||||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
from lerobot.common.transforms import Prod
|
||||||
|
|
||||||
# DATA_DIR specifies the location where datasets are loaded. By default, DATA_DIR is None and
|
# DATA_DIR specifies the location where datasets are loaded. By default, DATA_DIR is None and
|
||||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||||
|
@ -13,57 +12,12 @@ from lerobot.common.transforms import NormalizeTransform, Prod
|
||||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||||
|
|
||||||
|
|
||||||
def make_offline_buffer(
|
def make_dataset(
|
||||||
cfg,
|
cfg,
|
||||||
overwrite_sampler=None,
|
|
||||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||||
normalize=True,
|
normalize=True,
|
||||||
overwrite_batch_size=None,
|
|
||||||
overwrite_prefetch=None,
|
|
||||||
stats_path=None,
|
stats_path=None,
|
||||||
):
|
):
|
||||||
if cfg.policy.balanced_sampling:
|
|
||||||
assert cfg.online_steps > 0
|
|
||||||
batch_size = None
|
|
||||||
pin_memory = False
|
|
||||||
prefetch = None
|
|
||||||
else:
|
|
||||||
assert cfg.online_steps == 0
|
|
||||||
num_slices = cfg.policy.batch_size
|
|
||||||
batch_size = cfg.policy.horizon * num_slices
|
|
||||||
pin_memory = cfg.device == "cuda"
|
|
||||||
prefetch = cfg.prefetch
|
|
||||||
|
|
||||||
if overwrite_batch_size is not None:
|
|
||||||
batch_size = overwrite_batch_size
|
|
||||||
|
|
||||||
if overwrite_prefetch is not None:
|
|
||||||
prefetch = overwrite_prefetch
|
|
||||||
|
|
||||||
if overwrite_sampler is None:
|
|
||||||
# TODO(rcadene): move batch_size outside
|
|
||||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
|
||||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
|
||||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
|
||||||
|
|
||||||
if cfg.offline_prioritized_sampler:
|
|
||||||
logging.info("use prioritized sampler for offline dataset")
|
|
||||||
sampler = PrioritizedSliceSampler(
|
|
||||||
max_capacity=100_000,
|
|
||||||
alpha=cfg.policy.per_alpha,
|
|
||||||
beta=cfg.policy.per_beta,
|
|
||||||
num_slices=num_traj_per_batch,
|
|
||||||
strict_length=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logging.info("use simple sampler for offline dataset")
|
|
||||||
sampler = SliceSampler(
|
|
||||||
num_slices=num_traj_per_batch,
|
|
||||||
strict_length=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sampler = overwrite_sampler
|
|
||||||
|
|
||||||
if cfg.env.name == "simxarm":
|
if cfg.env.name == "simxarm":
|
||||||
from lerobot.common.datasets.simxarm import SimxarmDataset
|
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||||
|
|
||||||
|
@ -81,56 +35,56 @@ def make_offline_buffer(
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.env.name)
|
raise ValueError(cfg.env.name)
|
||||||
|
|
||||||
offline_buffer = clsfunc(
|
transforms = None
|
||||||
dataset_id=cfg.dataset_id,
|
|
||||||
sampler=sampler,
|
|
||||||
batch_size=batch_size,
|
|
||||||
root=DATA_DIR,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.policy.name == "tdmpc":
|
|
||||||
img_keys = []
|
|
||||||
for key in offline_buffer.image_keys:
|
|
||||||
img_keys.append(("next", *key))
|
|
||||||
img_keys += offline_buffer.image_keys
|
|
||||||
else:
|
|
||||||
img_keys = offline_buffer.image_keys
|
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
|
||||||
|
|
||||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||||
# min_max_from_spec
|
# min_max_from_spec
|
||||||
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||||
|
|
||||||
# we only normalize the state and action, since the images are usually normalized inside the model for
|
stats = {}
|
||||||
# now (except for tdmpc: see the following)
|
|
||||||
in_keys = [("observation", "state"), ("action")]
|
|
||||||
|
|
||||||
if cfg.policy.name == "tdmpc":
|
|
||||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
|
||||||
in_keys += img_keys
|
|
||||||
# TODO(rcadene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
|
||||||
in_keys += [("next", *key) for key in img_keys]
|
|
||||||
in_keys.append(("next", "observation", "state"))
|
|
||||||
|
|
||||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
stats["observation.state"] = {}
|
||||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
stats["action"] = {}
|
||||||
|
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||||
|
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
|
|
||||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
# normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
|
||||||
|
|
||||||
offline_buffer.set_transform(transforms)
|
transforms = v2.Compose(
|
||||||
|
[
|
||||||
|
# TODO(rcadene): we need to do something about image_keys
|
||||||
|
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||||
|
# NormalizeTransform(
|
||||||
|
# stats,
|
||||||
|
# in_keys=[
|
||||||
|
# "observation.state",
|
||||||
|
# "action",
|
||||||
|
# ],
|
||||||
|
# mode=normalization_mode,
|
||||||
|
# ),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if not overwrite_sampler:
|
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
# TODO(rcadene): implement delta_timestamps in config
|
||||||
sampler.extend(index)
|
delta_timestamps = {
|
||||||
|
"observation.image": [-0.1, 0],
|
||||||
|
"observation.state": [-0.1, 0],
|
||||||
|
"action": [-0.1] + [i / clsfunc.fps for i in range(15)],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
delta_timestamps = None
|
||||||
|
|
||||||
return offline_buffer
|
dataset = clsfunc(
|
||||||
|
dataset_id=cfg.dataset_id,
|
||||||
|
root=DATA_DIR,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
transform=transforms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
|
@ -1,20 +1,13 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pygame
|
import pygame
|
||||||
import pymunk
|
import pymunk
|
||||||
import torch
|
import torch
|
||||||
import torchrl
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from tensordict import TensorDict
|
|
||||||
from torchrl.data.replay_buffers.samplers import Sampler
|
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
|
||||||
from torchrl.data.replay_buffers.writers import Writer
|
|
||||||
|
|
||||||
from lerobot.common.datasets.abstract import AbstractDataset
|
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
|
||||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
|
||||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||||
|
|
||||||
|
@ -83,37 +76,82 @@ def add_tee(
|
||||||
return body
|
return body
|
||||||
|
|
||||||
|
|
||||||
class PushtDataset(AbstractDataset):
|
class PushtDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
----------
|
||||||
|
delta_timestamps : dict[list[float]] | None, optional
|
||||||
|
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
|
||||||
|
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
|
||||||
|
"""
|
||||||
|
|
||||||
available_datasets = ["pusht"]
|
available_datasets = ["pusht"]
|
||||||
|
fps = 10
|
||||||
|
image_keys = ["observation.image"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
version: str | None = "v1.2",
|
version: str | None = "v1.2",
|
||||||
batch_size: int | None = None,
|
|
||||||
*,
|
|
||||||
shuffle: bool = True,
|
|
||||||
root: Path | None = None,
|
root: Path | None = None,
|
||||||
pin_memory: bool = False,
|
transform: callable = None,
|
||||||
prefetch: int = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
sampler: Sampler | None = None,
|
|
||||||
collate_fn: Callable | None = None,
|
|
||||||
writer: Writer | None = None,
|
|
||||||
transform: "torchrl.envs.Transform" = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__()
|
||||||
dataset_id,
|
self.dataset_id = dataset_id
|
||||||
version,
|
self.version = version
|
||||||
batch_size,
|
self.root = root
|
||||||
shuffle=shuffle,
|
self.transform = transform
|
||||||
root=root,
|
self.delta_timestamps = delta_timestamps
|
||||||
pin_memory=pin_memory,
|
|
||||||
prefetch=prefetch,
|
data_dir = self.root / f"{self.dataset_id}"
|
||||||
sampler=sampler,
|
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||||
collate_fn=collate_fn,
|
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||||
writer=writer,
|
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||||
transform=transform,
|
else:
|
||||||
|
self._download_and_preproc_obsolete()
|
||||||
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||||
|
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_samples(self) -> int:
|
||||||
|
return len(self.data_dict["index"])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_episodes(self) -> int:
|
||||||
|
return len(self.data_ids_per_episode)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
item = {}
|
||||||
|
|
||||||
|
# get episode id and timestamp of the sampled frame
|
||||||
|
current_ts = self.data_dict["timestamp"][idx].item()
|
||||||
|
episode = self.data_dict["episode"][idx].item()
|
||||||
|
|
||||||
|
for key in self.data_dict:
|
||||||
|
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||||
|
data, is_pad = load_data_with_delta_timestamps(
|
||||||
|
self.data_dict,
|
||||||
|
self.data_ids_per_episode,
|
||||||
|
self.delta_timestamps,
|
||||||
|
key,
|
||||||
|
current_ts,
|
||||||
|
episode,
|
||||||
)
|
)
|
||||||
|
item[key] = data
|
||||||
|
item[f"{key}_is_pad"] = is_pad
|
||||||
|
else:
|
||||||
|
item[key] = self.data_dict[key][idx]
|
||||||
|
|
||||||
|
if self.transform is not None:
|
||||||
|
item = self.transform(item)
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
def _download_and_preproc_obsolete(self):
|
def _download_and_preproc_obsolete(self):
|
||||||
assert self.root is not None
|
assert self.root is not None
|
||||||
|
@ -147,8 +185,10 @@ class PushtDataset(AbstractDataset):
|
||||||
states = torch.from_numpy(dataset_dict["state"])
|
states = torch.from_numpy(dataset_dict["state"])
|
||||||
actions = torch.from_numpy(dataset_dict["action"])
|
actions = torch.from_numpy(dataset_dict["action"])
|
||||||
|
|
||||||
|
self.data_ids_per_episode = {}
|
||||||
|
ep_dicts = []
|
||||||
|
|
||||||
idx0 = 0
|
idx0 = 0
|
||||||
idxtd = 0
|
|
||||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||||
# to create test artifact
|
# to create test artifact
|
||||||
|
@ -194,30 +234,45 @@ class PushtDataset(AbstractDataset):
|
||||||
# last step of demonstration is considered done
|
# last step of demonstration is considered done
|
||||||
done[-1] = True
|
done[-1] = True
|
||||||
|
|
||||||
ep_td = TensorDict(
|
ep_dict = {
|
||||||
{
|
"observation.image": image,
|
||||||
("observation", "image"): image[:-1],
|
"observation.state": agent_pos,
|
||||||
("observation", "state"): agent_pos[:-1],
|
"action": actions[idx0:idx1],
|
||||||
"action": actions[idx0:idx1][:-1],
|
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||||
"episode": episode_ids[idx0:idx1][:-1],
|
"frame_id": torch.arange(0, num_frames, 1),
|
||||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||||
("next", "observation", "image"): image[1:],
|
# "next.observation.image": image[1:],
|
||||||
("next", "observation", "state"): agent_pos[1:],
|
# "next.observation.state": agent_pos[1:],
|
||||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||||
("next", "reward"): reward[1:],
|
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||||
("next", "done"): done[1:],
|
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||||
("next", "success"): success[1:],
|
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||||
},
|
}
|
||||||
batch_size=num_frames - 1,
|
ep_dicts.append(ep_dict)
|
||||||
)
|
|
||||||
|
|
||||||
if episode_id == 0:
|
assert isinstance(episode_id, int)
|
||||||
# hack to initialize tensordict data structure to store episodes
|
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||||
|
|
||||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
|
||||||
|
|
||||||
idx0 = idx1
|
idx0 = idx1
|
||||||
idxtd = idxtd + len(ep_td)
|
|
||||||
|
|
||||||
return TensorStorage(td_data.lock_())
|
self.data_dict = {}
|
||||||
|
|
||||||
|
keys = ep_dicts[0].keys()
|
||||||
|
for key in keys:
|
||||||
|
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||||
|
|
||||||
|
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dataset = PushtDataset(
|
||||||
|
"pusht",
|
||||||
|
root=Path("data"),
|
||||||
|
delta_timestamps={
|
||||||
|
"observation.image": [0, -1, -0.2, -0.1],
|
||||||
|
"observation.state": [0, -1, -0.2, -0.1],
|
||||||
|
"action": [-0.1, 0, 1, 2, 3],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dataset[10]
|
||||||
|
|
|
@ -1,75 +1,104 @@
|
||||||
import pickle
|
import pickle
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchrl
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from tensordict import TensorDict
|
|
||||||
from torchrl.data.replay_buffers.samplers import (
|
|
||||||
Sampler,
|
|
||||||
)
|
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
|
||||||
from torchrl.data.replay_buffers.writers import Writer
|
|
||||||
|
|
||||||
from lerobot.common.datasets.abstract import AbstractDataset
|
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||||
|
|
||||||
|
|
||||||
def download():
|
def download(raw_dir):
|
||||||
raise NotImplementedError()
|
|
||||||
import gdown
|
import gdown
|
||||||
|
|
||||||
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||||
download_path = "data.zip"
|
zip_path = raw_dir / "data.zip"
|
||||||
gdown.download(url, download_path, quiet=False)
|
gdown.download(url, str(zip_path), quiet=False)
|
||||||
print("Extracting...")
|
print("Extracting...")
|
||||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||||
for member in zip_f.namelist():
|
for member in zip_f.namelist():
|
||||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||||
print(member)
|
print(member)
|
||||||
zip_f.extract(member=member)
|
zip_f.extract(member=member)
|
||||||
Path(download_path).unlink()
|
zip_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
class SimxarmDataset(AbstractDataset):
|
class SimxarmDataset(torch.utils.data.Dataset):
|
||||||
available_datasets = [
|
available_datasets = [
|
||||||
"xarm_lift_medium",
|
"xarm_lift_medium",
|
||||||
]
|
]
|
||||||
|
fps = 15
|
||||||
|
image_keys = ["observation.image"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
version: str | None = "v1.1",
|
version: str | None = "v1.1",
|
||||||
batch_size: int | None = None,
|
|
||||||
*,
|
|
||||||
shuffle: bool = True,
|
|
||||||
root: Path | None = None,
|
root: Path | None = None,
|
||||||
pin_memory: bool = False,
|
transform: callable = None,
|
||||||
prefetch: int = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
sampler: Sampler | None = None,
|
|
||||||
collate_fn: Callable | None = None,
|
|
||||||
writer: Writer | None = None,
|
|
||||||
transform: "torchrl.envs.Transform" = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__()
|
||||||
dataset_id,
|
self.dataset_id = dataset_id
|
||||||
version,
|
self.version = version
|
||||||
batch_size,
|
self.root = root
|
||||||
shuffle=shuffle,
|
self.transform = transform
|
||||||
root=root,
|
self.delta_timestamps = delta_timestamps
|
||||||
pin_memory=pin_memory,
|
|
||||||
prefetch=prefetch,
|
data_dir = self.root / f"{self.dataset_id}"
|
||||||
sampler=sampler,
|
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||||
collate_fn=collate_fn,
|
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||||
writer=writer,
|
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||||
transform=transform,
|
else:
|
||||||
|
self._download_and_preproc_obsolete()
|
||||||
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||||
|
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_samples(self) -> int:
|
||||||
|
return len(self.data_dict["index"])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_episodes(self) -> int:
|
||||||
|
return len(self.data_ids_per_episode)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
    """Assemble the sample stored at position ``idx``.

    For every key of ``self.data_dict``:
      * if the key is listed in ``self.delta_timestamps``, the value is fetched with
        `load_data_with_delta_timestamps` and an extra ``"<key>_is_pad"`` entry is added;
      * otherwise the value is simply ``self.data_dict[key][idx]``.

    The resulting dict is passed through ``self.transform`` when one is set.
    """
    # episode id and timestamp of the sampled frame
    frame_ts = self.data_dict["timestamp"][idx].item()
    episode_id = self.data_dict["episode"][idx].item()
    deltas = self.delta_timestamps

    sample = {}
    for name, values in self.data_dict.items():
        if deltas is None or name not in deltas:
            sample[name] = values[idx]
            continue

        loaded, pad_mask = load_data_with_delta_timestamps(
            self.data_dict,
            self.data_ids_per_episode,
            deltas,
            name,
            frame_ts,
            episode_id,
        )
        sample[name] = loaded
        sample[f"{name}_is_pad"] = pad_mask

    if self.transform is None:
        return sample
    return self.transform(sample)
|
||||||
|
|
||||||
def _download_and_preproc_obsolete(self):
|
def _download_and_preproc_obsolete(self):
|
||||||
# assert self.root is not None
|
assert self.root is not None
|
||||||
# TODO(rcadene): finish download
|
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||||
# download()
|
if not raw_dir.exists():
|
||||||
|
download(raw_dir)
|
||||||
|
|
||||||
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
||||||
print(f"Using offline dataset '{dataset_path}'")
|
print(f"Using offline dataset '{dataset_path}'")
|
||||||
|
@ -78,6 +107,9 @@ class SimxarmDataset(AbstractDataset):
|
||||||
|
|
||||||
total_frames = dataset_dict["actions"].shape[0]
|
total_frames = dataset_dict["actions"].shape[0]
|
||||||
|
|
||||||
|
self.data_ids_per_episode = {}
|
||||||
|
ep_dicts = []
|
||||||
|
|
||||||
idx0 = 0
|
idx0 = 0
|
||||||
idx1 = 0
|
idx1 = 0
|
||||||
episode_id = 0
|
episode_id = 0
|
||||||
|
@ -91,37 +123,38 @@ class SimxarmDataset(AbstractDataset):
|
||||||
|
|
||||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||||
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
||||||
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
# TODO(rcadene): concat the last "next_observations" to "observations"
|
||||||
|
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||||
|
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||||
|
|
||||||
episode = TensorDict(
|
ep_dict = {
|
||||||
{
|
"observation.image": image,
|
||||||
("observation", "image"): image,
|
"observation.state": state,
|
||||||
("observation", "state"): state,
|
"action": action,
|
||||||
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
|
|
||||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||||
"frame_id": torch.arange(0, num_frames, 1),
|
"frame_id": torch.arange(0, num_frames, 1),
|
||||||
("next", "observation", "image"): next_image,
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||||
("next", "observation", "state"): next_state,
|
# "next.observation.image": next_image,
|
||||||
("next", "reward"): next_reward,
|
# "next.observation.state": next_state,
|
||||||
("next", "done"): next_done,
|
"next.reward": next_reward,
|
||||||
},
|
"next.done": next_done,
|
||||||
batch_size=num_frames,
|
}
|
||||||
)
|
ep_dicts.append(ep_dict)
|
||||||
|
|
||||||
if episode_id == 0:
|
assert isinstance(episode_id, int)
|
||||||
# hack to initialize tensordict data structure to store episodes
|
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||||
td_data = (
|
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||||
episode[0]
|
|
||||||
.expand(total_frames)
|
|
||||||
.memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer")
|
|
||||||
)
|
|
||||||
|
|
||||||
td_data[idx0:idx1] = episode
|
|
||||||
|
|
||||||
episode_id += 1
|
|
||||||
idx0 = idx1
|
idx0 = idx1
|
||||||
|
episode_id += 1
|
||||||
|
|
||||||
return TensorStorage(td_data.lock_())
|
self.data_dict = {}
|
||||||
|
|
||||||
|
keys = ep_dicts[0].keys()
|
||||||
|
for key in keys:
|
||||||
|
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||||
|
|
||||||
|
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
|
@ -3,6 +3,7 @@ import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,3 +29,71 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def euclidean_distance_matrix(mat0, mat1):
    """Return the matrix of Euclidean distances between the rows of ``mat0`` and the rows of ``mat1``.

    Entry ``[i, j]`` is the distance between ``mat0[i]`` and ``mat1[j]``, computed through the expansion
    ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``. Both inputs are treated as 2-D tensors (rows are points).
    """
    mat1_t = mat1.transpose(0, 1)
    row_sq_norms = (mat0**2).sum(dim=1, keepdim=True)
    col_sq_norms = (mat1**2).sum(dim=1, keepdim=True).transpose(0, 1)
    squared = row_sq_norms + col_sq_norms - 2 * mat0 @ mat1_t

    # the expansion can yield slightly negative values; clamp them to zero before the square root
    return torch.sqrt(torch.clamp(squared, min=0))
|
||||||
|
|
||||||
|
|
||||||
|
def is_contiguously_true_or_false(bool_vector):
    """Tell whether a 1-D boolean tensor switches value at most once.

    Returns True for vectors that are all True, all False, or made of a single run of one value followed
    by a single run of the other; returns False as soon as there are two or more switches.
    """
    assert bool_vector.ndim == 1
    assert bool_vector.dtype == torch.bool

    # a "flip" is any position whose value differs from its left neighbour
    left, right = bool_vector[:-1], bool_vector[1:]
    flips = int((right != left).sum())

    # more than one flip means the True (or False) entries are not in one contiguous run
    return flips <= 1
|
||||||
|
|
||||||
|
# examples = [
|
||||||
|
# ([True, False, True, False, False, False], False),
|
||||||
|
# ([True, True, True, False, False, False], True),
|
||||||
|
# ([False, False, False, False, False, False], True)
|
||||||
|
# ]
|
||||||
|
# for bool_list, expected in examples:
|
||||||
|
# result = is_contiguously_true_or_false(bool_list)
|
||||||
|
|
||||||
|
|
||||||
|
def load_data_with_delta_timestamps(
    data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, *, tol=0.02
):
    """Fetch the frames of `key` closest to `current_ts + delta` for every delta in `delta_timestamps[key]`.

    Args:
        data_dict: mapping from key to a tensor indexed by frame; must contain "timestamp" and `key`.
        data_ids_per_episode: mapping from episode id to the tensor of frame indices of that episode.
        delta_timestamps: mapping from key to the list of time offsets to query around `current_ts`.
        key: name of the entry of `data_dict` to load.
        current_ts: timestamp of the reference frame.
        episode: id of the episode the reference frame belongs to.
        tol: maximum distance allowed between a query timestamp and the closest frame timestamp
            (same unit as the timestamps); queries further away than this are flagged as padding.

    Returns:
        A tuple `(data, is_pad)` where `data` is a copy of the closest frames (one per query) and `is_pad`
        is a boolean tensor that is True where the closest frame is further than `tol` from the query.

    Raises:
        AssertionError: if the padded queries do not form a single contiguous run.
    """
    # get indices of the frames associated to the episode, and their timestamps
    ep_data_ids = data_ids_per_episode[episode]
    ep_timestamps = data_dict["timestamp"][ep_data_ids]

    # get timestamps used as query to retrieve data of previous/future frames
    query_ts = current_ts + torch.tensor(delta_timestamps[key])

    # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
    dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
    min_, argmin_ = dist.min(1)

    # get the indices of the data that are closest to the query timestamps, then copy the data so that
    # callers can modify it without touching the stored tensors
    data_ids = ep_data_ids[argmin_]
    data = data_dict[key][data_ids].clone()

    # TODO(rcadene): synchronize timestamps + interpolation if needed

    is_pad = min_ > tol

    assert is_contiguously_true_or_false(is_pad), (
        "One or several timestamps unexpectedly violate the tolerance. "
        "This might be due to synchronization issues with timestamps during data collection."
    )

    return data, is_pad
|
||||||
|
|
|
@ -1,64 +1,40 @@
|
||||||
from torchrl.envs import SerialEnv
|
import gymnasium as gym
|
||||||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
    """
    Build the gymnasium environment selected by `cfg.env.name`.

    Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
    returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.

    Raises:
        NotImplementedError: for "simxarm" and "aloha", for which no environment constructor is defined here yet.
        ValueError: for any other unknown `cfg.env.name`.
    """
    env_name = cfg.env.name

    if env_name == "pusht":
        import gym_pusht  # noqa

        # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
        kwargs = {
            "obs_type": "pixels_agent_pos",
            "render_action": False,
        }

        def env_fn():
            return gym.make(
                "gym_pusht/PushTPixels-v0",
                render_mode="rgb_array",
                max_episode_steps=cfg.env.episode_length,
                **kwargs,
            )

    elif env_name in ("simxarm", "aloha"):
        # These branches previously only filled `kwargs` and never defined `env_fn`, so the call below
        # crashed with an UnboundLocalError. Fail explicitly until a constructor is wired in.
        raise NotImplementedError(f"Environment '{env_name}' is not available through gymnasium yet.")
    else:
        raise ValueError(cfg.env.name)

    if num_parallel_envs == 0:
        # non-batched version of the env that returns an observation of shape (c)
        return env_fn()

    # batched version of the env that returns an observation of shape (b, c)
    return gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)])
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ class SimxarmEnv(AbstractEnv):
|
||||||
if not _has_gym:
|
if not _has_gym:
|
||||||
raise ImportError("Cannot import gymnasium.")
|
raise ImportError("Cannot import gymnasium.")
|
||||||
|
|
||||||
import gymnasium
|
import gymnasium as gym
|
||||||
|
|
||||||
from lerobot.common.envs.simxarm.simxarm import TASKS
|
from lerobot.common.envs.simxarm.simxarm import TASKS
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ class SimxarmEnv(AbstractEnv):
|
||||||
self._env = TASKS[self.task]["env"]()
|
self._env = TASKS[self.task]["env"]()
|
||||||
|
|
||||||
num_actions = len(TASKS[self.task]["action_space"])
|
num_actions = len(TASKS[self.task]["action_space"])
|
||||||
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||||
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
||||||
if "w" not in TASKS[self.task]["action_space"]:
|
if "w" not in TASKS[self.task]["action_space"]:
|
||||||
self._action_padding[-1] = 1.0
|
self._action_padding[-1] = 1.0
|
||||||
|
|
|
@ -1,18 +1,20 @@
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.common.policies.abstract import AbstractPolicy
|
|
||||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
|
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
|
||||||
|
from lerobot.common.policies.utils import populate_queues
|
||||||
from lerobot.common.utils import get_safe_torch_device
|
from lerobot.common.utils import get_safe_torch_device
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPolicy(AbstractPolicy):
|
class DiffusionPolicy(nn.Module):
|
||||||
name = "diffusion"
|
name = "diffusion"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -38,8 +40,12 @@ class DiffusionPolicy(AbstractPolicy):
|
||||||
# parameters passed to step
|
# parameters passed to step
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(n_action_steps)
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
self.n_action_steps = n_action_steps
|
||||||
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
|
self._queues = None
|
||||||
|
|
||||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||||
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
|
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
|
||||||
|
@ -100,76 +106,58 @@ class DiffusionPolicy(AbstractPolicy):
|
||||||
last_epoch=self.global_step - 1,
|
last_epoch=self.global_step - 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def reset(self):
    """
    Clear observation and action queues. Should be called on `env.reset()`
    """
    # observation queues keep the `n_obs_steps` latest entries, the action queue keeps `n_action_steps`
    queue_lengths = {
        "observation.image": self.n_obs_steps,
        "observation.state": self.n_obs_steps,
        "action": self.n_action_steps,
    }
    self._queues = {name: deque(maxlen=length) for name, length in queue_lengths.items()}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_actions(self, observation, step_count):
|
def select_action(self, batch, step):
|
||||||
"""
|
"""
|
||||||
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
||||||
"""
|
"""
|
||||||
# TODO(rcadene): remove unused step_count
|
# TODO(rcadene): remove unused step
|
||||||
del step_count
|
del step
|
||||||
|
assert "observation.image" in batch
|
||||||
|
assert "observation.state" in batch
|
||||||
|
assert len(batch) == 2
|
||||||
|
|
||||||
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
|
if len(self._queues["action"]) == 0:
|
||||||
|
# stack n latest observations from the queue
|
||||||
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||||
|
|
||||||
obs_dict = {
|
obs_dict = {
|
||||||
"image": observation["image"],
|
"image": batch["observation.image"],
|
||||||
"agent_pos": observation["state"],
|
"agent_pos": batch["observation.state"],
|
||||||
}
|
}
|
||||||
if self.training:
|
if self.training:
|
||||||
out = self.diffusion.predict_action(obs_dict)
|
out = self.diffusion.predict_action(obs_dict)
|
||||||
else:
|
else:
|
||||||
out = self.ema_diffusion.predict_action(obs_dict)
|
out = self.ema_diffusion.predict_action(obs_dict)
|
||||||
action = out["action"]
|
self._queues["action"].extend(out["action"].transpose(0, 1))
|
||||||
|
|
||||||
|
action = self._queues["action"].popleft()
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def update(self, replay_buffer, step):
|
def forward(self, batch, step):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
self.diffusion.train()
|
self.diffusion.train()
|
||||||
|
|
||||||
num_slices = self.cfg.batch_size
|
|
||||||
batch_size = self.cfg.horizon * num_slices
|
|
||||||
|
|
||||||
assert batch_size % self.cfg.horizon == 0
|
|
||||||
assert batch_size % num_slices == 0
|
|
||||||
|
|
||||||
def process_batch(batch, horizon, num_slices):
|
|
||||||
# trajectory t = 64, horizon h = 16
|
|
||||||
# (t h) ... -> t h ...
|
|
||||||
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
|
|
||||||
|
|
||||||
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
|
|
||||||
# |o|o| observations: 2
|
|
||||||
# | |a|a|a|a|a|a|a|a| actions executed: 8
|
|
||||||
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
|
|
||||||
# note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
|
|
||||||
|
|
||||||
image = batch["observation", "image"]
|
|
||||||
state = batch["observation", "state"]
|
|
||||||
action = batch["action"]
|
|
||||||
assert image.shape[1] == horizon
|
|
||||||
assert state.shape[1] == horizon
|
|
||||||
assert action.shape[1] == horizon
|
|
||||||
|
|
||||||
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
# keep first 2 observations of the slice corresponding to t=[-1,0]
|
|
||||||
image = image[:, : self.cfg.n_obs_steps]
|
|
||||||
state = state[:, : self.cfg.n_obs_steps]
|
|
||||||
|
|
||||||
out = {
|
|
||||||
"obs": {
|
|
||||||
"image": image.to(self.device, non_blocking=True),
|
|
||||||
"agent_pos": state.to(self.device, non_blocking=True),
|
|
||||||
},
|
|
||||||
"action": action.to(self.device, non_blocking=True),
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
|
|
||||||
batch = replay_buffer.sample(batch_size)
|
|
||||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
|
||||||
|
|
||||||
data_s = time.time() - start_time
|
data_s = time.time() - start_time
|
||||||
|
|
||||||
loss = self.diffusion.compute_loss(batch)
|
obs_dict = {
|
||||||
|
"image": batch["observation.image"],
|
||||||
|
"agent_pos": batch["observation.state"],
|
||||||
|
}
|
||||||
|
loss = self.diffusion.compute_loss(obs_dict)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
|
|
@ -1,53 +1,49 @@
|
||||||
from typing import Sequence
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tensordict import TensorDictBase
|
from torchvision.transforms.v2 import Compose, Transform
|
||||||
from tensordict.nn import dispatch
|
|
||||||
from tensordict.utils import NestedKey
|
|
||||||
from torchrl.envs.transforms import ObservationTransform, Transform
|
|
||||||
|
|
||||||
|
|
||||||
class Prod(ObservationTransform):
|
def apply_inverse_transform(item, transform):
    """Undo `transform` on `item` by calling `inverse_transform` on each transform, last one first.

    Args:
        item: the object to un-transform (a dict of tensors at the visible call sites).
        transform: a single transform, a `Compose` of transforms, or None.

    Returns:
        The un-transformed item; `item` itself, untouched, when `transform` is None.

    Raises:
        ValueError: if one of the transforms does not declare itself as `invertible`.
    """
    if transform is None:
        # nothing was applied, so there is nothing to undo
        return item

    transforms = transform.transforms if isinstance(transform, Compose) else [transform]
    for tf in reversed(transforms):
        # transforms without an `invertible` attribute are treated as non invertible
        if not getattr(tf, "invertible", False):
            raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
        item = tf.inverse_transform(item)
    return item
|
||||||
|
|
||||||
|
|
||||||
|
class Prod(Transform):
|
||||||
invertible = True
|
invertible = True
|
||||||
|
|
||||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
def __init__(self, in_keys: list[str], prod: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_keys = in_keys
|
self.in_keys = in_keys
|
||||||
self.prod = prod
|
self.prod = prod
|
||||||
self.original_dtypes = {}
|
self.original_dtypes = {}
|
||||||
|
|
||||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
def forward(self, item):
|
||||||
# _reset is called once when the environment reset to normalize the first observation
|
|
||||||
tensordict_reset = self._call(tensordict_reset)
|
|
||||||
return tensordict_reset
|
|
||||||
|
|
||||||
@dispatch(source="in_keys", dest="out_keys")
|
|
||||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
|
||||||
return self._call(tensordict)
|
|
||||||
|
|
||||||
def _call(self, td):
|
|
||||||
for key in self.in_keys:
|
for key in self.in_keys:
|
||||||
if td.get(key, None) is None:
|
if key not in item:
|
||||||
continue
|
continue
|
||||||
self.original_dtypes[key] = td[key].dtype
|
self.original_dtypes[key] = item[key].dtype
|
||||||
td[key] = td[key].type(torch.float32) * self.prod
|
item[key] = item[key].type(torch.float32) * self.prod
|
||||||
return td
|
return item
|
||||||
|
|
||||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
def inverse_transform(self, item):
|
||||||
for key in self.in_keys:
|
for key in self.in_keys:
|
||||||
if td.get(key, None) is None:
|
if key not in item:
|
||||||
continue
|
continue
|
||||||
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
|
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
|
||||||
return td
|
return item
|
||||||
|
|
||||||
def transform_observation_spec(self, obs_spec):
|
# def transform_observation_spec(self, obs_spec):
|
||||||
for key in self.in_keys:
|
# for key in self.in_keys:
|
||||||
if obs_spec.get(key, None) is None:
|
# if obs_spec.get(key, None) is None:
|
||||||
continue
|
# continue
|
||||||
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||||
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||||
obs_spec[key].dtype = torch.float32
|
# obs_spec[key].dtype = torch.float32
|
||||||
return obs_spec
|
# return obs_spec
|
||||||
|
|
||||||
|
|
||||||
class NormalizeTransform(Transform):
|
class NormalizeTransform(Transform):
|
||||||
|
@ -55,65 +51,50 @@ class NormalizeTransform(Transform):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stats: TensorDictBase,
|
stats: dict,
|
||||||
in_keys: Sequence[NestedKey] = None,
|
in_keys: list[str] = None,
|
||||||
out_keys: Sequence[NestedKey] | None = None,
|
out_keys: list[str] | None = None,
|
||||||
in_keys_inv: Sequence[NestedKey] | None = None,
|
in_keys_inv: list[str] | None = None,
|
||||||
out_keys_inv: Sequence[NestedKey] | None = None,
|
out_keys_inv: list[str] | None = None,
|
||||||
mode="mean_std",
|
mode="mean_std",
|
||||||
):
|
):
|
||||||
if out_keys is None:
|
super().__init__()
|
||||||
out_keys = in_keys
|
self.in_keys = in_keys
|
||||||
if in_keys_inv is None:
|
self.out_keys = in_keys if out_keys is None else out_keys
|
||||||
in_keys_inv = out_keys
|
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
|
||||||
if out_keys_inv is None:
|
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
|
||||||
out_keys_inv = in_keys
|
|
||||||
super().__init__(
|
|
||||||
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
|
||||||
)
|
|
||||||
self.stats = stats
|
self.stats = stats
|
||||||
assert mode in ["mean_std", "min_max"]
|
assert mode in ["mean_std", "min_max"]
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
def forward(self, item):
|
||||||
# _reset is called once when the environment reset to normalize the first observation
|
|
||||||
tensordict_reset = self._call(tensordict_reset)
|
|
||||||
return tensordict_reset
|
|
||||||
|
|
||||||
@dispatch(source="in_keys", dest="out_keys")
|
|
||||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
|
||||||
return self._call(tensordict)
|
|
||||||
|
|
||||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
|
||||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
if inkey not in item:
|
||||||
if td.get(inkey, None) is None:
|
|
||||||
continue
|
continue
|
||||||
if self.mode == "mean_std":
|
if self.mode == "mean_std":
|
||||||
mean = self.stats[inkey]["mean"]
|
mean = self.stats[inkey]["mean"]
|
||||||
std = self.stats[inkey]["std"]
|
std = self.stats[inkey]["std"]
|
||||||
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
|
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
|
||||||
else:
|
else:
|
||||||
min = self.stats[inkey]["min"]
|
min = self.stats[inkey]["min"]
|
||||||
max = self.stats[inkey]["max"]
|
max = self.stats[inkey]["max"]
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
td[outkey] = (td[inkey] - min) / (max - min)
|
item[outkey] = (item[inkey] - min) / (max - min)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
td[outkey] = td[outkey] * 2 - 1
|
item[outkey] = item[outkey] * 2 - 1
|
||||||
return td
|
return item
|
||||||
|
|
||||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
def inverse_transform(self, item):
|
||||||
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
||||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
if inkey not in item:
|
||||||
if td.get(inkey, None) is None:
|
|
||||||
continue
|
continue
|
||||||
if self.mode == "mean_std":
|
if self.mode == "mean_std":
|
||||||
mean = self.stats[inkey]["mean"]
|
mean = self.stats[inkey]["mean"]
|
||||||
std = self.stats[inkey]["std"]
|
std = self.stats[inkey]["std"]
|
||||||
td[outkey] = td[inkey] * std + mean
|
item[outkey] = item[inkey] * std + mean
|
||||||
else:
|
else:
|
||||||
min = self.stats[inkey]["min"]
|
min = self.stats[inkey]["min"]
|
||||||
max = self.stats[inkey]["max"]
|
max = self.stats[inkey]["max"]
|
||||||
td[outkey] = (td[inkey] + 1) / 2
|
item[outkey] = (item[inkey] + 1) / 2
|
||||||
td[outkey] = td[outkey] * (max - min) + min
|
item[outkey] = item[outkey] * (max - min) + min
|
||||||
return td
|
return item
|
||||||
|
|
|
@ -36,92 +36,177 @@ from datetime import datetime as dt
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
import gymnasium as gym
|
||||||
|
import hydra
|
||||||
import imageio
|
import imageio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from tensordict.nn import TensorDictModule
|
|
||||||
from torchrl.envs import EnvBase
|
|
||||||
from torchrl.envs.batched_envs import BatchedEnvBase
|
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import log_output_dir
|
from lerobot.common.logger import log_output_dir
|
||||||
from lerobot.common.policies.abstract import AbstractPolicy
|
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||||
|
from lerobot.common.transforms import apply_inverse_transform
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path, stacked_frames, fps):
|
def write_video(video_path, stacked_frames, fps):
|
||||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_observation(observation, transform=None):
    """Convert a batched environment observation into the dict of tensors expected by the policy.

    `observation["pixels"]` and `observation["agent_pos"]` are numpy arrays; they are converted to float
    tensors and returned under "observation.image" (rearranged from `b h w c` to `b c h w`) and
    "observation.state". When `transform` is given, it is applied to each item of the batch separately.
    """
    image = torch.from_numpy(observation["pixels"]).float()
    state = torch.from_numpy(observation["agent_pos"]).float()

    # map to expected inputs for the policy, with images in (b c h w) torch format
    obs = {
        "observation.image": einops.rearrange(image, "b h w c -> b c h w"),
        "observation.state": state,
    }

    if transform is None:
        return obs

    # apply same transforms as in training: they work on single items, so unbatch, transform and re-stack
    for key, batched in obs.items():
        obs[key] = torch.stack([transform({key: sample})[key] for sample in batched])
    return obs
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess_action(action, transform=None):
    """Move the policy action to CPU, undo the training transform and return it as a numpy array.

    The result is expected to be 2-D: (number of parallel envs, action dimensions).
    """
    cpu_action = action.to("cpu")
    # action is a batch (num_env,action_dim) instead of an item (action_dim),
    # we assume applying inverse transform on a batch works the same
    restored = apply_inverse_transform({"action": cpu_action}, transform)
    action = restored["action"].numpy()
    assert (
        action.ndim == 2
    ), "we assume dimensions are respectively the number of parallel envs, action dimensions"
    return action
|
||||||
|
|
||||||
|
|
||||||
def eval_policy(
|
def eval_policy(
|
||||||
env: BatchedEnvBase,
|
env: gym.vector.VectorEnv,
|
||||||
policy: AbstractPolicy,
|
policy,
|
||||||
num_episodes: int = 10,
|
|
||||||
max_steps: int = 30,
|
|
||||||
save_video: bool = False,
|
save_video: bool = False,
|
||||||
video_dir: Path = None,
|
video_dir: Path = None,
|
||||||
|
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
|
||||||
fps: int = 15,
|
fps: int = 15,
|
||||||
return_first_video: bool = False,
|
return_first_video: bool = False,
|
||||||
|
transform: callable = None,
|
||||||
):
|
):
|
||||||
if policy is not None:
|
if policy is not None:
|
||||||
policy.eval()
|
policy.eval()
|
||||||
start = time.time()
|
start = time.time()
|
||||||
sum_rewards = []
|
sum_rewards = []
|
||||||
max_rewards = []
|
max_rewards = []
|
||||||
successes = []
|
all_successes = []
|
||||||
seeds = []
|
seeds = []
|
||||||
threads = [] # for video saving threads
|
threads = [] # for video saving threads
|
||||||
episode_counter = 0 # for saving the correct number of videos
|
episode_counter = 0 # for saving the correct number of videos
|
||||||
|
|
||||||
|
num_episodes = len(env.envs)
|
||||||
|
|
||||||
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
|
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
|
||||||
# needed as I'm currently taking a ceil.
|
# needed as I'm currently taking a ceil.
|
||||||
for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
|
|
||||||
ep_frames = []
|
ep_frames = []
|
||||||
|
|
||||||
def maybe_render_frame(env: EnvBase, _):
|
def maybe_render_frame(env):
|
||||||
if save_video or (return_first_video and i == 0): # noqa: B023
|
if save_video: # noqa: B023
|
||||||
ep_frames.append(env.render()) # noqa: B023
|
if return_first_video:
|
||||||
|
visu = env.envs[0].render()
|
||||||
|
visu = visu[None, ...] # add batch dim
|
||||||
|
else:
|
||||||
|
visu = np.stack([env.render() for env in env.envs])
|
||||||
|
ep_frames.append(visu) # noqa: B023
|
||||||
|
|
||||||
# Clear the policy's action queue before the start of a new rollout.
|
for _ in range(num_episodes):
|
||||||
if policy is not None:
|
seeds.append("TODO")
|
||||||
policy.clear_action_queue()
|
|
||||||
|
|
||||||
if env.is_closed:
|
if hasattr(policy, "reset"):
|
||||||
env.start() # needed to be able to get the seeds the first time as BatchedEnvs are lazy
|
policy.reset()
|
||||||
seeds.extend(env._next_seed)
|
else:
|
||||||
with torch.inference_mode():
|
logging.warning(
|
||||||
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
|
f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
|
||||||
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
|
|
||||||
rollout = env.rollout(
|
|
||||||
max_steps=max_steps,
|
|
||||||
policy=policy,
|
|
||||||
auto_cast_to_device=True,
|
|
||||||
callback=maybe_render_frame,
|
|
||||||
break_when_any_done=env.batch_size[0] == 1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# reset the environment
|
||||||
|
observation, info = env.reset(seed=cfg.seed)
|
||||||
|
maybe_render_frame(env)
|
||||||
|
|
||||||
|
rewards = []
|
||||||
|
successes = []
|
||||||
|
dones = []
|
||||||
|
|
||||||
|
done = torch.tensor([False for _ in env.envs])
|
||||||
|
step = 0
|
||||||
|
do_rollout = True
|
||||||
|
while do_rollout:
|
||||||
|
# apply transform to normalize the observations
|
||||||
|
observation = preprocess_observation(observation, transform)
|
||||||
|
|
||||||
|
# send observation to device/gpu
|
||||||
|
observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
|
||||||
|
|
||||||
|
# get the next action for the environment
|
||||||
|
with torch.inference_mode():
|
||||||
|
action = policy.select_action(observation, step)
|
||||||
|
|
||||||
|
# apply inverse transform to unnormalize the action
|
||||||
|
action = postprocess_action(action, transform)
|
||||||
|
|
||||||
|
# apply the next
|
||||||
|
observation, reward, terminated, truncated, info = env.step(action)
|
||||||
|
maybe_render_frame(env)
|
||||||
|
|
||||||
|
# TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
|
||||||
|
reward = torch.from_numpy(reward)
|
||||||
|
terminated = torch.from_numpy(terminated)
|
||||||
|
truncated = torch.from_numpy(truncated)
|
||||||
|
# environment is considered done (no more steps), when success state is reached (terminated is True),
|
||||||
|
# or time limit is reached (truncated is True), or it was previsouly done.
|
||||||
|
done = terminated | truncated | done
|
||||||
|
|
||||||
|
if "final_info" in info:
|
||||||
|
# VectorEnv stores is_success into `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]`
|
||||||
|
success = [
|
||||||
|
env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
success = [False for _ in env.envs]
|
||||||
|
success = torch.tensor(success)
|
||||||
|
|
||||||
|
rewards.append(reward)
|
||||||
|
dones.append(done)
|
||||||
|
successes.append(success)
|
||||||
|
|
||||||
|
step += 1
|
||||||
|
|
||||||
|
if done.all():
|
||||||
|
do_rollout = False
|
||||||
|
break
|
||||||
|
|
||||||
|
rewards = torch.stack(rewards, dim=1)
|
||||||
|
successes = torch.stack(successes, dim=1)
|
||||||
|
dones = torch.stack(dones, dim=1)
|
||||||
|
|
||||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||||
# this won't be included).
|
# this won't be included).
|
||||||
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
|
# Note: this assumes that the shape of the done key is (batch_size, max_steps).
|
||||||
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
||||||
rollout_steps = rollout["next", "done"].shape[1]
|
done_indices = torch.argmax(dones.to(int), axis=1) # (batch_size, rollout_steps)
|
||||||
done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps)
|
expand_done_indices = done_indices[:, None].expand(-1, step)
|
||||||
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
|
expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1)
|
||||||
batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
|
mask = (expand_step_indices <= expand_done_indices).int() # (batch_size, rollout_steps)
|
||||||
batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
|
batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum")
|
||||||
batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
|
batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max")
|
||||||
|
batch_success = einops.reduce((successes * mask), "b n -> b", "any")
|
||||||
sum_rewards.extend(batch_sum_reward.tolist())
|
sum_rewards.extend(batch_sum_reward.tolist())
|
||||||
max_rewards.extend(batch_max_reward.tolist())
|
max_rewards.extend(batch_max_reward.tolist())
|
||||||
successes.extend(batch_success.tolist())
|
all_successes.extend(batch_success.tolist())
|
||||||
|
|
||||||
if save_video or (return_first_video and i == 0):
|
env.close()
|
||||||
batch_stacked_frames = np.stack(ep_frames) # (t, b, *)
|
|
||||||
batch_stacked_frames = batch_stacked_frames.transpose(
|
if save_video or return_first_video:
|
||||||
1, 0, *range(2, batch_stacked_frames.ndim)
|
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||||
) # (b, t, *)
|
|
||||||
|
|
||||||
if save_video:
|
if save_video:
|
||||||
for stacked_frames, done_index in zip(
|
for stacked_frames, done_index in zip(
|
||||||
|
@ -139,7 +224,7 @@ def eval_policy(
|
||||||
threads.append(thread)
|
threads.append(thread)
|
||||||
episode_counter += 1
|
episode_counter += 1
|
||||||
|
|
||||||
if return_first_video and i == 0:
|
if return_first_video:
|
||||||
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
|
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
|
@ -158,16 +243,16 @@ def eval_policy(
|
||||||
zip(
|
zip(
|
||||||
sum_rewards[:num_episodes],
|
sum_rewards[:num_episodes],
|
||||||
max_rewards[:num_episodes],
|
max_rewards[:num_episodes],
|
||||||
successes[:num_episodes],
|
all_successes[:num_episodes],
|
||||||
seeds[:num_episodes],
|
seeds[:num_episodes],
|
||||||
strict=True,
|
strict=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
"aggregated": {
|
"aggregated": {
|
||||||
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
|
"avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])),
|
||||||
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
|
"avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])),
|
||||||
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
|
"pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100),
|
||||||
"eval_s": time.time() - start,
|
"eval_s": time.time() - start,
|
||||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||||
},
|
},
|
||||||
|
@ -194,21 +279,13 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||||
|
|
||||||
logging.info("Making transforms.")
|
logging.info("Making transforms.")
|
||||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||||
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
|
dataset = make_dataset(cfg, stats_path=stats_path)
|
||||||
|
|
||||||
logging.info("Making environment.")
|
logging.info("Making environment.")
|
||||||
env = make_env(cfg, transform=offline_buffer.transform)
|
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||||
|
|
||||||
if cfg.policy.pretrained_model_path:
|
|
||||||
policy = make_policy(cfg)
|
|
||||||
policy = TensorDictModule(
|
|
||||||
policy,
|
|
||||||
in_keys=["observation", "step_count"],
|
|
||||||
out_keys=["action"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# when policy is None, rollout a random policy
|
# when policy is None, rollout a random policy
|
||||||
policy = None
|
policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
|
||||||
|
|
||||||
info = eval_policy(
|
info = eval_policy(
|
||||||
env,
|
env,
|
||||||
|
@ -216,8 +293,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||||
save_video=True,
|
save_video=True,
|
||||||
video_dir=Path(out_dir) / "eval",
|
video_dir=Path(out_dir) / "eval",
|
||||||
fps=cfg.env.fps,
|
fps=cfg.env.fps,
|
||||||
max_steps=cfg.env.episode_length,
|
# TODO(rcadene): what should we do with the transform?
|
||||||
num_episodes=cfg.eval_episodes,
|
transform=dataset.transform,
|
||||||
)
|
)
|
||||||
print(info["aggregated"])
|
print(info["aggregated"])
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,12 @@
|
||||||
import logging
|
import logging
|
||||||
|
from itertools import cycle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tensordict.nn import TensorDictModule
|
|
||||||
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
|
||||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
@ -34,7 +32,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
|
||||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||||
|
|
||||||
|
|
||||||
def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
|
def log_train_info(logger, info, step, cfg, dataset, is_offline):
|
||||||
loss = info["loss"]
|
loss = info["loss"]
|
||||||
grad_norm = info["grad_norm"]
|
grad_norm = info["grad_norm"]
|
||||||
lr = info["lr"]
|
lr = info["lr"]
|
||||||
|
@ -44,9 +42,9 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
|
||||||
# A sample is an (observation,action) pair, where observation and action
|
# A sample is an (observation,action) pair, where observation and action
|
||||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||||
num_samples = (step + 1) * cfg.policy.batch_size
|
num_samples = (step + 1) * cfg.policy.batch_size
|
||||||
avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
|
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
|
||||||
num_episodes = num_samples / avg_samples_per_ep
|
num_episodes = num_samples / avg_samples_per_ep
|
||||||
num_epochs = num_samples / offline_buffer.num_samples
|
num_epochs = num_samples / dataset.num_samples
|
||||||
log_items = [
|
log_items = [
|
||||||
f"step:{format_big_number(step)}",
|
f"step:{format_big_number(step)}",
|
||||||
# number of samples seen during training
|
# number of samples seen during training
|
||||||
|
@ -73,7 +71,7 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
|
||||||
logger.log_dict(info, step, mode="train")
|
logger.log_dict(info, step, mode="train")
|
||||||
|
|
||||||
|
|
||||||
def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
|
def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
||||||
eval_s = info["eval_s"]
|
eval_s = info["eval_s"]
|
||||||
avg_sum_reward = info["avg_sum_reward"]
|
avg_sum_reward = info["avg_sum_reward"]
|
||||||
pc_success = info["pc_success"]
|
pc_success = info["pc_success"]
|
||||||
|
@ -81,9 +79,9 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
|
||||||
# A sample is an (observation,action) pair, where observation and action
|
# A sample is an (observation,action) pair, where observation and action
|
||||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||||
num_samples = (step + 1) * cfg.policy.batch_size
|
num_samples = (step + 1) * cfg.policy.batch_size
|
||||||
avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
|
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
|
||||||
num_episodes = num_samples / avg_samples_per_ep
|
num_episodes = num_samples / avg_samples_per_ep
|
||||||
num_epochs = num_samples / offline_buffer.num_samples
|
num_epochs = num_samples / dataset.num_samples
|
||||||
log_items = [
|
log_items = [
|
||||||
f"step:{format_big_number(step)}",
|
f"step:{format_big_number(step)}",
|
||||||
# number of samples seen during training
|
# number of samples seen during training
|
||||||
|
@ -124,30 +122,30 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_global_seed(cfg.seed)
|
set_global_seed(cfg.seed)
|
||||||
|
|
||||||
logging.info("make_offline_buffer")
|
logging.info("make_dataset")
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
# TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
|
# TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
|
||||||
if cfg.policy.balanced_sampling:
|
# if cfg.policy.balanced_sampling:
|
||||||
logging.info("make online_buffer")
|
# logging.info("make online_buffer")
|
||||||
num_traj_per_batch = cfg.policy.batch_size
|
# num_traj_per_batch = cfg.policy.batch_size
|
||||||
|
|
||||||
online_sampler = PrioritizedSliceSampler(
|
# online_sampler = PrioritizedSliceSampler(
|
||||||
max_capacity=100_000,
|
# max_capacity=100_000,
|
||||||
alpha=cfg.policy.per_alpha,
|
# alpha=cfg.policy.per_alpha,
|
||||||
beta=cfg.policy.per_beta,
|
# beta=cfg.policy.per_beta,
|
||||||
num_slices=num_traj_per_batch,
|
# num_slices=num_traj_per_batch,
|
||||||
strict_length=True,
|
# strict_length=True,
|
||||||
)
|
# )
|
||||||
|
|
||||||
online_buffer = TensorDictReplayBuffer(
|
# online_buffer = TensorDictReplayBuffer(
|
||||||
storage=LazyMemmapStorage(100_000),
|
# storage=LazyMemmapStorage(100_000),
|
||||||
sampler=online_sampler,
|
# sampler=online_sampler,
|
||||||
transform=offline_buffer.transform,
|
# transform=dataset.transform,
|
||||||
)
|
# )
|
||||||
|
|
||||||
logging.info("make_env")
|
logging.info("make_env")
|
||||||
env = make_env(cfg, transform=offline_buffer.transform)
|
env = make_env(cfg)
|
||||||
|
|
||||||
logging.info("make_policy")
|
logging.info("make_policy")
|
||||||
policy = make_policy(cfg)
|
policy = make_policy(cfg)
|
||||||
|
@ -155,8 +153,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
|
||||||
td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])
|
|
||||||
|
|
||||||
# log metrics to terminal and wandb
|
# log metrics to terminal and wandb
|
||||||
logger = Logger(out_dir, job_name, cfg)
|
logger = Logger(out_dir, job_name, cfg)
|
||||||
|
|
||||||
|
@ -165,8 +161,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
|
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
|
||||||
logging.info(f"{cfg.online_steps=}")
|
logging.info(f"{cfg.online_steps=}")
|
||||||
logging.info(f"{cfg.env.action_repeat=}")
|
logging.info(f"{cfg.env.action_repeat=}")
|
||||||
logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
|
logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
|
||||||
logging.info(f"{offline_buffer.num_episodes=}")
|
logging.info(f"{dataset.num_episodes=}")
|
||||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||||
|
|
||||||
|
@ -176,14 +172,15 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
eval_info, first_video = eval_policy(
|
eval_info, first_video = eval_policy(
|
||||||
env,
|
env,
|
||||||
td_policy,
|
policy,
|
||||||
num_episodes=cfg.eval_episodes,
|
num_episodes=cfg.eval_episodes,
|
||||||
max_steps=cfg.env.episode_length,
|
max_steps=cfg.env.episode_length,
|
||||||
return_first_video=True,
|
return_first_video=True,
|
||||||
video_dir=Path(out_dir) / "eval",
|
video_dir=Path(out_dir) / "eval",
|
||||||
save_video=True,
|
save_video=True,
|
||||||
|
transform=dataset.transform,
|
||||||
)
|
)
|
||||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline)
|
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
|
||||||
if cfg.wandb.enable:
|
if cfg.wandb.enable:
|
||||||
logger.log_video(first_video, step, mode="eval")
|
logger.log_video(first_video, step, mode="eval")
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
@ -196,14 +193,29 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
step = 0 # number of policy update (forward + backward + optim)
|
step = 0 # number of policy update (forward + backward + optim)
|
||||||
|
|
||||||
is_offline = True
|
is_offline = True
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
num_workers=4,
|
||||||
|
batch_size=cfg.policy.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
pin_memory=cfg.device != "cpu",
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
dl_iter = cycle(dataloader)
|
||||||
for offline_step in range(cfg.offline_steps):
|
for offline_step in range(cfg.offline_steps):
|
||||||
if offline_step == 0:
|
if offline_step == 0:
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
|
||||||
policy.train()
|
policy.train()
|
||||||
train_info = policy.update(offline_buffer, step)
|
batch = next(dl_iter)
|
||||||
|
|
||||||
|
for key in batch:
|
||||||
|
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||||
|
|
||||||
|
train_info = policy.update(batch, step)
|
||||||
|
|
||||||
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.log_freq == 0:
|
if step % cfg.log_freq == 0:
|
||||||
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
|
log_train_info(logger, train_info, step, cfg, dataset, is_offline)
|
||||||
|
|
||||||
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
|
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
|
||||||
# step + 1.
|
# step + 1.
|
||||||
|
@ -211,7 +223,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
|
demo_buffer = dataset if cfg.policy.balanced_sampling else None
|
||||||
online_step = 0
|
online_step = 0
|
||||||
is_offline = False
|
is_offline = False
|
||||||
for env_step in range(cfg.online_steps):
|
for env_step in range(cfg.online_steps):
|
||||||
|
@ -221,7 +233,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
rollout = env.rollout(
|
rollout = env.rollout(
|
||||||
max_steps=cfg.env.episode_length,
|
max_steps=cfg.env.episode_length,
|
||||||
policy=td_policy,
|
policy=policy,
|
||||||
auto_cast_to_device=True,
|
auto_cast_to_device=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -242,7 +254,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
|
|
||||||
# set same episode index for all time steps contained in this rollout
|
# set same episode index for all time steps contained in this rollout
|
||||||
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
|
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
|
||||||
online_buffer.extend(rollout)
|
# online_buffer.extend(rollout)
|
||||||
|
|
||||||
ep_sum_reward = rollout["next", "reward"].sum()
|
ep_sum_reward = rollout["next", "reward"].sum()
|
||||||
ep_max_reward = rollout["next", "reward"].max()
|
ep_max_reward = rollout["next", "reward"].max()
|
||||||
|
@ -257,13 +269,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
|
|
||||||
for _ in range(cfg.policy.utd):
|
for _ in range(cfg.policy.utd):
|
||||||
train_info = policy.update(
|
train_info = policy.update(
|
||||||
online_buffer,
|
# online_buffer,
|
||||||
step,
|
step,
|
||||||
demo_buffer=demo_buffer,
|
demo_buffer=demo_buffer,
|
||||||
)
|
)
|
||||||
if step % cfg.log_freq == 0:
|
if step % cfg.log_freq == 0:
|
||||||
train_info.update(rollout_info)
|
train_info.update(rollout_info)
|
||||||
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
|
log_train_info(logger, train_info, step, cfg, dataset, is_offline)
|
||||||
|
|
||||||
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
|
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
|
||||||
# in step + 1.
|
# in step + 1.
|
||||||
|
|
|
@ -10,7 +10,7 @@ from torchrl.data.replay_buffers import (
|
||||||
SamplerWithoutReplacement,
|
SamplerWithoutReplacement,
|
||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.logger import log_output_dir
|
from lerobot.common.logger import log_output_dir
|
||||||
from lerobot.common.utils import init_logging
|
from lerobot.common.utils import init_logging
|
||||||
|
|
||||||
|
@ -44,8 +44,8 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("make_offline_buffer")
|
logging.info("make_dataset")
|
||||||
offline_buffer = make_offline_buffer(
|
dataset = make_dataset(
|
||||||
cfg,
|
cfg,
|
||||||
overwrite_sampler=sampler,
|
overwrite_sampler=sampler,
|
||||||
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
|
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
|
||||||
|
@ -55,12 +55,12 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Start rendering episodes from offline buffer")
|
logging.info("Start rendering episodes from offline buffer")
|
||||||
video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
|
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
|
||||||
for video_path in video_paths:
|
for video_path in video_paths:
|
||||||
logging.info(video_path)
|
logging.info(video_path)
|
||||||
|
|
||||||
|
|
||||||
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
|
def render_dataset(dataset, out_dir, max_num_samples, fps):
|
||||||
out_dir = Path(out_dir)
|
out_dir = Path(out_dir)
|
||||||
video_paths = []
|
video_paths = []
|
||||||
threads = []
|
threads = []
|
||||||
|
@ -69,17 +69,17 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
|
||||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||||
for i in range(max_num_samples):
|
for i in range(max_num_samples):
|
||||||
# TODO(rcadene): make it work with bsize > 1
|
# TODO(rcadene): make it work with bsize > 1
|
||||||
ep_td = offline_buffer.sample(1)
|
ep_td = dataset.sample(1)
|
||||||
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||||
|
|
||||||
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
# TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||||
num_frames_left = offline_buffer._sampler._sample_list.numel()
|
num_frames_left = dataset._sampler._sample_list.numel()
|
||||||
episode_is_done = ep_idx != current_ep_idx
|
episode_is_done = ep_idx != current_ep_idx
|
||||||
|
|
||||||
if episode_is_done:
|
if episode_is_done:
|
||||||
logging.info(f"Rendering episode {current_ep_idx}")
|
logging.info(f"Rendering episode {current_ep_idx}")
|
||||||
|
|
||||||
for im_key in offline_buffer.image_keys:
|
for im_key in dataset.image_keys:
|
||||||
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
|
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
|
||||||
# when first frame of episode, initialize frames dict
|
# when first frame of episode, initialize frames dict
|
||||||
if im_key not in frames:
|
if im_key not in frames:
|
||||||
|
@ -93,7 +93,7 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
|
||||||
frames[im_key].append(ep_td["next"][im_key])
|
frames[im_key].append(ep_td["next"][im_key])
|
||||||
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
if len(offline_buffer.image_keys) > 1:
|
if len(dataset.image_keys) > 1:
|
||||||
camera = im_key[-1]
|
camera = im_key[-1]
|
||||||
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -879,6 +879,29 @@ files = [
|
||||||
[package.extras]
|
[package.extras]
|
||||||
protobuf = ["grpcio-tools (>=1.62.1)"]
|
protobuf = ["grpcio-tools (>=1.62.1)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "gym-pusht"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "PushT environment for LeRobot"
|
||||||
|
optional = true
|
||||||
|
python-versions = "^3.10"
|
||||||
|
files = []
|
||||||
|
develop = false
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
gymnasium = "^0.29.1"
|
||||||
|
opencv-python = "^4.9.0.80"
|
||||||
|
pygame = "^2.5.2"
|
||||||
|
pymunk = "^6.6.0"
|
||||||
|
scikit-image = "^0.22.0"
|
||||||
|
shapely = "^2.0.3"
|
||||||
|
|
||||||
|
[package.source]
|
||||||
|
type = "git"
|
||||||
|
url = "git@github.com:huggingface/gym-pusht.git"
|
||||||
|
reference = "HEAD"
|
||||||
|
resolved_reference = "d7e1a39a31b1368741e9674791007d7cccf046a3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gymnasium"
|
name = "gymnasium"
|
||||||
version = "0.29.1"
|
version = "0.29.1"
|
||||||
|
@ -3586,7 +3609,10 @@ files = [
|
||||||
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||||
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
|
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
|
||||||
|
|
||||||
|
[extras]
|
||||||
|
pusht = ["gym_pusht"]
|
||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "174c7d42f8039eedd2c447a4e6cae5169782cbd94346b5606572a0010194ca05"
|
content-hash = "3eee17e4bf2b7a570f41ef9c400ec5a24a3113f62a13162229cf43504ca0d005"
|
||||||
|
|
|
@ -52,7 +52,10 @@ robomimic = "0.2.0"
|
||||||
gymnasium-robotics = "^1.2.4"
|
gymnasium-robotics = "^1.2.4"
|
||||||
gymnasium = "^0.29.1"
|
gymnasium = "^0.29.1"
|
||||||
cmake = "^3.29.0.1"
|
cmake = "^3.29.0.1"
|
||||||
|
gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
|
||||||
|
|
||||||
|
[tool.poetry.extras]
|
||||||
|
pusht = ["gym_pusht"]
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pre-commit = "^3.6.2"
|
pre-commit = "^3.6.2"
|
||||||
|
|
|
@ -6,6 +6,8 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.utils import init_hydra_config
|
from lerobot.common.utils import init_hydra_config
|
||||||
|
import logging
|
||||||
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
|
||||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
@ -26,14 +28,29 @@ def test_factory(env_name, dataset_id):
|
||||||
DEFAULT_CONFIG_PATH,
|
DEFAULT_CONFIG_PATH,
|
||||||
overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
|
overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
|
||||||
)
|
)
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
dataset = make_dataset(cfg)
|
||||||
for key in offline_buffer.image_keys:
|
|
||||||
img = offline_buffer[0].get(key)
|
item = dataset[0]
|
||||||
|
|
||||||
|
assert "action" in item
|
||||||
|
assert "episode" in item
|
||||||
|
assert "frame_id" in item
|
||||||
|
assert "timestamp" in item
|
||||||
|
assert "next.done" in item
|
||||||
|
# TODO(rcadene): should we rename it agent_pos?
|
||||||
|
assert "observation.state" in item
|
||||||
|
for key in dataset.image_keys:
|
||||||
|
img = item.get(key)
|
||||||
assert img.dtype == torch.float32
|
assert img.dtype == torch.float32
|
||||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||||
assert img.max() <= 1.0
|
assert img.max() <= 1.0
|
||||||
assert img.min() >= 0.0
|
assert img.min() >= 0.0
|
||||||
|
|
||||||
|
if "next.reward" not in item:
|
||||||
|
logging.warning(f'Missing "next.reward" key in dataset {dataset}.')
|
||||||
|
if "next.done" not in item:
|
||||||
|
logging.warning(f'Missing "next.done" key in dataset {dataset}.')
|
||||||
|
|
||||||
|
|
||||||
def test_compute_stats():
|
def test_compute_stats():
|
||||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
import torch
|
import torch
|
||||||
from torchrl.envs.utils import check_env_specs, step_mdp
|
from torchrl.envs.utils import check_env_specs, step_mdp
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
|
||||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
|
@ -116,15 +116,15 @@ def test_factory(env_name):
|
||||||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||||
)
|
)
|
||||||
|
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
env = make_env(cfg)
|
env = make_env(cfg)
|
||||||
for key in offline_buffer.image_keys:
|
for key in dataset.image_keys:
|
||||||
assert env.reset().get(key).dtype == torch.uint8
|
assert env.reset().get(key).dtype == torch.uint8
|
||||||
check_env_specs(env)
|
check_env_specs(env)
|
||||||
|
|
||||||
env = make_env(cfg, transform=offline_buffer.transform)
|
env = make_env(cfg, transform=dataset.transform)
|
||||||
for key in offline_buffer.image_keys:
|
for key in dataset.image_keys:
|
||||||
img = env.reset().get(key)
|
img = env.reset().get(key)
|
||||||
assert img.dtype == torch.float32
|
assert img.dtype == torch.float32
|
||||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||||
|
|
|
@ -7,7 +7,7 @@ from torchrl.envs import EnvBase
|
||||||
|
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.policies.abstract import AbstractPolicy
|
from lerobot.common.policies.abstract import AbstractPolicy
|
||||||
from lerobot.common.utils import init_hydra_config
|
from lerobot.common.utils import init_hydra_config
|
||||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||||
|
@ -45,13 +45,13 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||||
# Check that we can make the policy object.
|
# Check that we can make the policy object.
|
||||||
policy = make_policy(cfg)
|
policy = make_policy(cfg)
|
||||||
# Check that we run select_actions and get the appropriate output.
|
# Check that we run select_actions and get the appropriate output.
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
dataset = make_dataset(cfg)
|
||||||
env = make_env(cfg, transform=offline_buffer.transform)
|
env = make_env(cfg, transform=dataset.transform)
|
||||||
|
|
||||||
if env_name != "aloha":
|
if env_name != "aloha":
|
||||||
# TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
|
# TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
|
||||||
# seq_length as a list is not supported for now.
|
# seq_length as a list is not supported for now.
|
||||||
policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
|
policy.update(dataset, torch.tensor(0, device=DEVICE))
|
||||||
|
|
||||||
action = policy(
|
action = policy(
|
||||||
env.observation_spec.rand()["observation"].to(DEVICE),
|
env.observation_spec.rand()["observation"].to(DEVICE),
|
||||||
|
|
Loading…
Reference in New Issue