Merge pull request #7 from Cadene/user/rcadene/2024_03_05_abstract_replay_buffer
Add AbstractReplayBuffer
This commit is contained in:
commit
49c0955f97
|
@ -0,0 +1,158 @@
|
||||||
|
import abc
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import torchrl
|
||||||
|
import tqdm
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.data.datasets.utils import _get_root_dir
|
||||||
|
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||||
|
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||||
|
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||||
|
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_id: str,
|
||||||
|
batch_size: int = None,
|
||||||
|
*,
|
||||||
|
shuffle: bool = True,
|
||||||
|
root: Path = None,
|
||||||
|
pin_memory: bool = False,
|
||||||
|
prefetch: int = None,
|
||||||
|
sampler: SliceSampler = None,
|
||||||
|
collate_fn: Callable = None,
|
||||||
|
writer: Writer = None,
|
||||||
|
transform: "torchrl.envs.Transform" = None,
|
||||||
|
):
|
||||||
|
self.dataset_id = dataset_id
|
||||||
|
self.shuffle = shuffle
|
||||||
|
self.root = _get_root_dir(self.dataset_id) if root is None else root
|
||||||
|
self.root = Path(self.root)
|
||||||
|
self.data_dir = self.root / self.dataset_id
|
||||||
|
|
||||||
|
storage = self._download_or_load_storage()
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
storage=storage,
|
||||||
|
sampler=sampler,
|
||||||
|
writer=ImmutableDatasetWriter() if writer is None else writer,
|
||||||
|
collate_fn=_collate_id if collate_fn is None else collate_fn,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch,
|
||||||
|
batch_size=batch_size,
|
||||||
|
transform=transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stats_patterns(self) -> dict:
|
||||||
|
return {
|
||||||
|
("observation", "state"): "b c -> 1 c",
|
||||||
|
("observation", "image"): "b c h w -> 1 c 1 1",
|
||||||
|
("action"): "b c -> 1 c",
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_keys(self) -> list:
|
||||||
|
return [("observation", "image")]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_cameras(self) -> int:
|
||||||
|
return len(self.image_keys)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_samples(self) -> int:
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_episodes(self) -> int:
|
||||||
|
return len(self._storage._storage["episode"].unique())
|
||||||
|
|
||||||
|
def set_transform(self, transform):
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||||
|
stats_path = self.data_dir / "stats.pth"
|
||||||
|
if stats_path.exists():
|
||||||
|
stats = torch.load(stats_path)
|
||||||
|
else:
|
||||||
|
logging.info(f"compute_stats and save to {stats_path}")
|
||||||
|
stats = self._compute_stats(num_batch, batch_size)
|
||||||
|
torch.save(stats, stats_path)
|
||||||
|
return stats
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _download_and_preproc(self) -> torch.StorageBase:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _download_or_load_storage(self):
|
||||||
|
if not self._is_downloaded():
|
||||||
|
storage = self._download_and_preproc()
|
||||||
|
else:
|
||||||
|
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
|
||||||
|
return storage
|
||||||
|
|
||||||
|
def _is_downloaded(self) -> bool:
|
||||||
|
return self.data_dir.is_dir()
|
||||||
|
|
||||||
|
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||||
|
rb = TensorDictReplayBuffer(
|
||||||
|
storage=self._storage,
|
||||||
|
batch_size=batch_size,
|
||||||
|
prefetch=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mean, std, max, min = {}, {}, {}, {}
|
||||||
|
|
||||||
|
# compute mean, min, max
|
||||||
|
for _ in tqdm.tqdm(range(num_batch)):
|
||||||
|
batch = rb.sample()
|
||||||
|
for key, pattern in self.stats_patterns.items():
|
||||||
|
batch[key] = batch[key].float()
|
||||||
|
if key not in mean:
|
||||||
|
# first batch initialize mean, min, max
|
||||||
|
mean[key] = einops.reduce(batch[key], pattern, "mean")
|
||||||
|
max[key] = einops.reduce(batch[key], pattern, "max")
|
||||||
|
min[key] = einops.reduce(batch[key], pattern, "min")
|
||||||
|
else:
|
||||||
|
mean[key] += einops.reduce(batch[key], pattern, "mean")
|
||||||
|
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||||
|
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||||
|
batch = rb.sample()
|
||||||
|
|
||||||
|
for key in self.stats_patterns:
|
||||||
|
mean[key] /= num_batch
|
||||||
|
|
||||||
|
# compute std, min, max
|
||||||
|
for _ in tqdm.tqdm(range(num_batch)):
|
||||||
|
batch = rb.sample()
|
||||||
|
for key, pattern in self.stats_patterns.items():
|
||||||
|
batch[key] = batch[key].float()
|
||||||
|
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||||
|
if key not in std:
|
||||||
|
# first batch initialize std
|
||||||
|
std[key] = (batch_mean - mean[key]) ** 2
|
||||||
|
else:
|
||||||
|
std[key] += (batch_mean - mean[key]) ** 2
|
||||||
|
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||||
|
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||||
|
|
||||||
|
for key in self.stats_patterns:
|
||||||
|
std[key] = torch.sqrt(std[key] / num_batch)
|
||||||
|
|
||||||
|
stats = TensorDict({}, batch_size=[])
|
||||||
|
for key in self.stats_patterns:
|
||||||
|
stats[(*key, "mean")] = mean[key]
|
||||||
|
stats[(*key, "std")] = std[key]
|
||||||
|
stats[(*key, "max")] = max[key]
|
||||||
|
stats[(*key, "min")] = min[key]
|
||||||
|
|
||||||
|
if key[0] == "observation":
|
||||||
|
# use same stats for the next observations
|
||||||
|
stats[("next", *key)] = stats[key]
|
||||||
|
return stats
|
|
@ -5,33 +5,14 @@ from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||||
|
|
||||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
from lerobot.common.envs.transforms import NormalizeTransform
|
||||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
|
||||||
|
|
||||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||||
|
|
||||||
# TODO(rcadene): implement
|
|
||||||
|
|
||||||
# dataset_d4rl = D4RLExperienceReplay(
|
def make_offline_buffer(
|
||||||
# dataset_id="maze2d-umaze-v1",
|
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
|
||||||
# split_trajs=False,
|
):
|
||||||
# batch_size=1,
|
|
||||||
# sampler=SamplerWithoutReplacement(drop_last=False),
|
|
||||||
# prefetch=4,
|
|
||||||
# direct_download=True,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# dataset_openx = OpenXExperienceReplay(
|
|
||||||
# "cmu_stretch",
|
|
||||||
# batch_size=1,
|
|
||||||
# num_slices=1,
|
|
||||||
# #download="force",
|
|
||||||
# streaming=False,
|
|
||||||
# root="data",
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
def make_offline_buffer(cfg, sampler=None):
|
|
||||||
if cfg.policy.balanced_sampling:
|
if cfg.policy.balanced_sampling:
|
||||||
assert cfg.online_steps > 0
|
assert cfg.online_steps > 0
|
||||||
batch_size = None
|
batch_size = None
|
||||||
|
@ -44,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None):
|
||||||
pin_memory = cfg.device == "cuda"
|
pin_memory = cfg.device == "cuda"
|
||||||
prefetch = cfg.prefetch
|
prefetch = cfg.prefetch
|
||||||
|
|
||||||
overwrite_sampler = sampler is not None
|
if overwrite_batch_size is not None:
|
||||||
|
batch_size = overwrite_batch_size
|
||||||
|
|
||||||
if not overwrite_sampler:
|
if overwrite_prefetch is not None:
|
||||||
|
prefetch = overwrite_prefetch
|
||||||
|
|
||||||
|
if overwrite_sampler is None:
|
||||||
# TODO(rcadene): move batch_size outside
|
# TODO(rcadene): move batch_size outside
|
||||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||||
|
@ -67,36 +52,57 @@ def make_offline_buffer(cfg, sampler=None):
|
||||||
num_slices=num_traj_per_batch,
|
num_slices=num_traj_per_batch,
|
||||||
strict_length=False,
|
strict_length=False,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
sampler = overwrite_sampler
|
||||||
|
|
||||||
if cfg.env.name == "simxarm":
|
if cfg.env.name == "simxarm":
|
||||||
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
|
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||||
offline_buffer = SimxarmExperienceReplay(
|
|
||||||
f"xarm_{cfg.env.task}_medium",
|
clsfunc = SimxarmExperienceReplay
|
||||||
# download="force",
|
dataset_id = f"xarm_{cfg.env.task}_medium"
|
||||||
download=True,
|
|
||||||
streaming=False,
|
|
||||||
root=str(DATA_DIR),
|
|
||||||
sampler=sampler,
|
|
||||||
batch_size=batch_size,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
|
||||||
)
|
|
||||||
elif cfg.env.name == "pusht":
|
elif cfg.env.name == "pusht":
|
||||||
offline_buffer = PushtExperienceReplay(
|
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||||
"pusht",
|
|
||||||
streaming=False,
|
clsfunc = PushtExperienceReplay
|
||||||
root=DATA_DIR,
|
dataset_id = "pusht"
|
||||||
sampler=sampler,
|
|
||||||
batch_size=batch_size,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.env.name)
|
raise ValueError(cfg.env.name)
|
||||||
|
|
||||||
|
offline_buffer = clsfunc(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
root=DATA_DIR,
|
||||||
|
sampler=sampler,
|
||||||
|
batch_size=batch_size,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||||
|
stats = offline_buffer.compute_or_load_stats()
|
||||||
|
in_keys = [("observation", "state"), ("action")]
|
||||||
|
|
||||||
|
if cfg.policy == "tdmpc":
|
||||||
|
for key in offline_buffer.image_keys:
|
||||||
|
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||||
|
in_keys.append(key)
|
||||||
|
# since we use next observations in tdmpc
|
||||||
|
in_keys.append(("next", *key))
|
||||||
|
in_keys.append(("next", "observation", "state"))
|
||||||
|
|
||||||
|
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
|
||||||
|
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||||
|
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||||
|
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||||
|
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||||
|
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
|
|
||||||
|
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||||
|
offline_buffer.set_transform(transform)
|
||||||
|
|
||||||
if not overwrite_sampler:
|
if not overwrite_sampler:
|
||||||
num_steps = len(offline_buffer)
|
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||||
index = torch.arange(0, num_steps, 1)
|
|
||||||
sampler.extend(index)
|
sampler.extend(index)
|
||||||
|
|
||||||
return offline_buffer
|
return offline_buffer
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
@ -12,16 +9,14 @@ import torch
|
||||||
import torchrl
|
import torchrl
|
||||||
import tqdm
|
import tqdm
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torchrl.data.datasets.utils import _get_root_dir
|
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||||
from torchrl.data.replay_buffers.samplers import Sampler
|
from torchrl.data.replay_buffers.writers import Writer
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
|
||||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
|
||||||
|
|
||||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||||
|
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||||
from lerobot.common.envs.transforms import NormalizeTransform
|
|
||||||
|
|
||||||
# as define in env
|
# as define in env
|
||||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||||
|
@ -87,114 +82,36 @@ def add_tee(
|
||||||
return body
|
return body
|
||||||
|
|
||||||
|
|
||||||
class PushtExperienceReplay(TensorDictReplayBuffer):
|
class PushtExperienceReplay(AbstractExperienceReplay):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
batch_size: int = None,
|
batch_size: int = None,
|
||||||
*,
|
*,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
num_slices: int = None,
|
|
||||||
slice_len: int = None,
|
|
||||||
pad: float = None,
|
|
||||||
replacement: bool = None,
|
|
||||||
streaming: bool = False,
|
|
||||||
root: Path = None,
|
root: Path = None,
|
||||||
sampler: Sampler = None,
|
|
||||||
writer: Writer = None,
|
|
||||||
collate_fn: Callable = None,
|
|
||||||
pin_memory: bool = False,
|
pin_memory: bool = False,
|
||||||
prefetch: int = None,
|
prefetch: int = None,
|
||||||
transform: "torchrl.envs.Transform" = None, # noqa: F821
|
sampler: SliceSampler = None,
|
||||||
split_trajs: bool = False,
|
collate_fn: Callable = None,
|
||||||
strict_length: bool = True,
|
writer: Writer = None,
|
||||||
|
transform: "torchrl.envs.Transform" = None,
|
||||||
):
|
):
|
||||||
if streaming:
|
|
||||||
raise NotImplementedError
|
|
||||||
self.streaming = streaming
|
|
||||||
self.dataset_id = dataset_id
|
|
||||||
self.split_trajs = split_trajs
|
|
||||||
self.shuffle = shuffle
|
|
||||||
self.num_slices = num_slices
|
|
||||||
self.slice_len = slice_len
|
|
||||||
self.pad = pad
|
|
||||||
|
|
||||||
self.strict_length = strict_length
|
|
||||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
|
||||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
|
||||||
if split_trajs:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
if root is None:
|
|
||||||
root = _get_root_dir("pusht")
|
|
||||||
os.makedirs(root, exist_ok=True)
|
|
||||||
|
|
||||||
self.root = root
|
|
||||||
if not self._is_downloaded():
|
|
||||||
storage = self._download_and_preproc()
|
|
||||||
else:
|
|
||||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
|
||||||
|
|
||||||
stats = self._compute_or_load_stats(storage)
|
|
||||||
transform = NormalizeTransform(
|
|
||||||
stats,
|
|
||||||
in_keys=[
|
|
||||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy
|
|
||||||
# We need to automate this for tdmpc and others
|
|
||||||
# ("observation", "image"),
|
|
||||||
("observation", "state"),
|
|
||||||
# TODO(rcadene): for tdmpc, we might want next image and state
|
|
||||||
# ("next", "observation", "image"),
|
|
||||||
# ("next", "observation", "state"),
|
|
||||||
("action"),
|
|
||||||
],
|
|
||||||
mode="min_max",
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
|
||||||
transform.stats["observation", "state", "min"] = torch.tensor(
|
|
||||||
[13.456424, 32.938293], dtype=torch.float32
|
|
||||||
)
|
|
||||||
transform.stats["observation", "state", "max"] = torch.tensor(
|
|
||||||
[496.14618, 510.9579], dtype=torch.float32
|
|
||||||
)
|
|
||||||
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
|
||||||
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
|
||||||
|
|
||||||
if writer is None:
|
|
||||||
writer = ImmutableDatasetWriter()
|
|
||||||
if collate_fn is None:
|
|
||||||
collate_fn = _collate_id
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
storage=storage,
|
dataset_id,
|
||||||
sampler=sampler,
|
batch_size,
|
||||||
writer=writer,
|
shuffle=shuffle,
|
||||||
collate_fn=collate_fn,
|
root=root,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
prefetch=prefetch,
|
prefetch=prefetch,
|
||||||
batch_size=batch_size,
|
sampler=sampler,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
writer=writer,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def num_samples(self) -> int:
|
|
||||||
return len(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_episodes(self) -> int:
|
|
||||||
return len(self._storage._storage["episode"].unique())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data_path_root(self) -> Path:
|
|
||||||
return None if self.streaming else self.root / self.dataset_id
|
|
||||||
|
|
||||||
def _is_downloaded(self) -> bool:
|
|
||||||
return self.data_path_root.is_dir()
|
|
||||||
|
|
||||||
def _download_and_preproc(self):
|
def _download_and_preproc(self):
|
||||||
# download
|
raw_dir = self.data_dir / "raw"
|
||||||
raw_dir = self.root / "raw"
|
|
||||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||||
if not zarr_path.is_dir():
|
if not zarr_path.is_dir():
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -266,8 +183,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
# last step of demonstration is considered done
|
# last step of demonstration is considered done
|
||||||
done[-1] = True
|
done[-1] = True
|
||||||
|
|
||||||
print("before " + """episode = TensorDict(""")
|
ep_td = TensorDict(
|
||||||
episode = TensorDict(
|
|
||||||
{
|
{
|
||||||
("observation", "image"): image[:-1],
|
("observation", "image"): image[:-1],
|
||||||
("observation", "state"): agent_pos[:-1],
|
("observation", "state"): agent_pos[:-1],
|
||||||
|
@ -286,120 +202,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
|
|
||||||
if episode_id == 0:
|
if episode_id == 0:
|
||||||
# hack to initialize tensordict data structure to store episodes
|
# hack to initialize tensordict data structure to store episodes
|
||||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
|
||||||
|
|
||||||
td_data[idxtd : idxtd + len(episode)] = episode
|
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||||
|
|
||||||
idx0 = idx1
|
idx0 = idx1
|
||||||
idxtd = idxtd + len(episode)
|
idxtd = idxtd + len(ep_td)
|
||||||
|
|
||||||
return TensorStorage(td_data.lock_())
|
return TensorStorage(td_data.lock_())
|
||||||
|
|
||||||
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
|
||||||
rb = TensorDictReplayBuffer(
|
|
||||||
storage=storage,
|
|
||||||
batch_size=batch_size,
|
|
||||||
prefetch=True,
|
|
||||||
)
|
|
||||||
batch = rb.sample()
|
|
||||||
|
|
||||||
image_channels = batch["observation", "image"].shape[1]
|
|
||||||
image_mean = torch.zeros(image_channels)
|
|
||||||
image_std = torch.zeros(image_channels)
|
|
||||||
image_max = torch.tensor([-math.inf] * image_channels)
|
|
||||||
image_min = torch.tensor([math.inf] * image_channels)
|
|
||||||
|
|
||||||
state_channels = batch["observation", "state"].shape[1]
|
|
||||||
state_mean = torch.zeros(state_channels)
|
|
||||||
state_std = torch.zeros(state_channels)
|
|
||||||
state_max = torch.tensor([-math.inf] * state_channels)
|
|
||||||
state_min = torch.tensor([math.inf] * state_channels)
|
|
||||||
|
|
||||||
action_channels = batch["action"].shape[1]
|
|
||||||
action_mean = torch.zeros(action_channels)
|
|
||||||
action_std = torch.zeros(action_channels)
|
|
||||||
action_max = torch.tensor([-math.inf] * action_channels)
|
|
||||||
action_min = torch.tensor([math.inf] * action_channels)
|
|
||||||
|
|
||||||
for _ in tqdm.tqdm(range(num_batch)):
|
|
||||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
|
||||||
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
|
||||||
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
|
|
||||||
|
|
||||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
|
||||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
|
||||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
|
||||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
|
||||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
|
||||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
|
||||||
image_max = torch.maximum(image_max, b_image_max)
|
|
||||||
image_min = torch.maximum(image_min, b_image_min)
|
|
||||||
state_max = torch.maximum(state_max, b_state_max)
|
|
||||||
state_min = torch.maximum(state_min, b_state_min)
|
|
||||||
action_max = torch.maximum(action_max, b_action_max)
|
|
||||||
action_min = torch.maximum(action_min, b_action_min)
|
|
||||||
|
|
||||||
batch = rb.sample()
|
|
||||||
|
|
||||||
image_mean /= num_batch
|
|
||||||
state_mean /= num_batch
|
|
||||||
action_mean /= num_batch
|
|
||||||
|
|
||||||
for i in tqdm.tqdm(range(num_batch)):
|
|
||||||
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
|
||||||
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
|
||||||
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
|
|
||||||
image_std += (b_image_mean - image_mean) ** 2
|
|
||||||
state_std += (b_state_mean - state_mean) ** 2
|
|
||||||
action_std += (b_action_mean - action_mean) ** 2
|
|
||||||
|
|
||||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
|
||||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
|
||||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
|
||||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
|
||||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
|
||||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
|
||||||
image_max = torch.maximum(image_max, b_image_max)
|
|
||||||
image_min = torch.maximum(image_min, b_image_min)
|
|
||||||
state_max = torch.maximum(state_max, b_state_max)
|
|
||||||
state_min = torch.maximum(state_min, b_state_min)
|
|
||||||
action_max = torch.maximum(action_max, b_action_max)
|
|
||||||
action_min = torch.maximum(action_min, b_action_min)
|
|
||||||
|
|
||||||
if i < num_batch - 1:
|
|
||||||
batch = rb.sample()
|
|
||||||
|
|
||||||
image_std = torch.sqrt(image_std / num_batch)
|
|
||||||
state_std = torch.sqrt(state_std / num_batch)
|
|
||||||
action_std = torch.sqrt(action_std / num_batch)
|
|
||||||
|
|
||||||
stats = TensorDict(
|
|
||||||
{
|
|
||||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
|
||||||
("observation", "image", "std"): image_std[None, :, None, None],
|
|
||||||
("observation", "image", "max"): image_max[None, :, None, None],
|
|
||||||
("observation", "image", "min"): image_min[None, :, None, None],
|
|
||||||
("observation", "state", "mean"): state_mean[None, :],
|
|
||||||
("observation", "state", "std"): state_std[None, :],
|
|
||||||
("observation", "state", "max"): state_max[None, :],
|
|
||||||
("observation", "state", "min"): state_min[None, :],
|
|
||||||
("action", "mean"): action_mean[None, :],
|
|
||||||
("action", "std"): action_std[None, :],
|
|
||||||
("action", "max"): action_max[None, :],
|
|
||||||
("action", "min"): action_min[None, :],
|
|
||||||
},
|
|
||||||
batch_size=[],
|
|
||||||
)
|
|
||||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
|
||||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
|
||||||
return stats
|
|
||||||
|
|
||||||
def _compute_or_load_stats(self, storage) -> TensorDict:
|
|
||||||
stats_path = self.root / self.dataset_id / "stats.pth"
|
|
||||||
if stats_path.exists():
|
|
||||||
stats = torch.load(stats_path)
|
|
||||||
else:
|
|
||||||
logging.info(f"compute_stats and save to {stats_path}")
|
|
||||||
stats = self._compute_stats(storage)
|
|
||||||
torch.save(stats, stats_path)
|
|
||||||
return stats
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
@ -7,130 +6,52 @@ import torch
|
||||||
import torchrl
|
import torchrl
|
||||||
import tqdm
|
import tqdm
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torchrl.data.datasets.utils import _get_root_dir
|
|
||||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
|
||||||
from torchrl.data.replay_buffers.samplers import (
|
from torchrl.data.replay_buffers.samplers import (
|
||||||
Sampler,
|
|
||||||
SliceSampler,
|
SliceSampler,
|
||||||
SliceSamplerWithoutReplacement,
|
|
||||||
)
|
)
|
||||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
from torchrl.data.replay_buffers.writers import Writer
|
||||||
|
|
||||||
|
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||||
|
|
||||||
|
|
||||||
class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||||
available_datasets = [
|
available_datasets = [
|
||||||
"xarm_lift_medium",
|
"xarm_lift_medium",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_id,
|
dataset_id: str,
|
||||||
batch_size: int = None,
|
batch_size: int = None,
|
||||||
*,
|
*,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
num_slices: int = None,
|
|
||||||
slice_len: int = None,
|
|
||||||
pad: float = None,
|
|
||||||
replacement: bool = None,
|
|
||||||
streaming: bool = False,
|
|
||||||
root: Path = None,
|
root: Path = None,
|
||||||
download: bool = False,
|
|
||||||
sampler: Sampler = None,
|
|
||||||
writer: Writer = None,
|
|
||||||
collate_fn: Callable = None,
|
|
||||||
pin_memory: bool = False,
|
pin_memory: bool = False,
|
||||||
prefetch: int = None,
|
prefetch: int = None,
|
||||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
sampler: SliceSampler = None,
|
||||||
split_trajs: bool = False,
|
collate_fn: Callable = None,
|
||||||
strict_length: bool = True,
|
writer: Writer = None,
|
||||||
|
transform: "torchrl.envs.Transform" = None,
|
||||||
):
|
):
|
||||||
self.download = download
|
|
||||||
if streaming:
|
|
||||||
raise NotImplementedError
|
|
||||||
self.streaming = streaming
|
|
||||||
self.dataset_id = dataset_id
|
|
||||||
self.split_trajs = split_trajs
|
|
||||||
self.shuffle = shuffle
|
|
||||||
self.num_slices = num_slices
|
|
||||||
self.slice_len = slice_len
|
|
||||||
self.pad = pad
|
|
||||||
|
|
||||||
self.strict_length = strict_length
|
|
||||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
|
||||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
|
||||||
if split_trajs:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
if root is None:
|
|
||||||
root = _get_root_dir("simxarm")
|
|
||||||
os.makedirs(root, exist_ok=True)
|
|
||||||
self.root = Path(root)
|
|
||||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
|
||||||
storage = self._download_and_preproc()
|
|
||||||
else:
|
|
||||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
|
||||||
|
|
||||||
if num_slices is not None or slice_len is not None:
|
|
||||||
if sampler is not None:
|
|
||||||
raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
|
|
||||||
|
|
||||||
if replacement:
|
|
||||||
if not self.shuffle:
|
|
||||||
raise RuntimeError("shuffle=False can only be used when replacement=False.")
|
|
||||||
sampler = SliceSampler(
|
|
||||||
num_slices=num_slices,
|
|
||||||
slice_len=slice_len,
|
|
||||||
strict_length=strict_length,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sampler = SliceSamplerWithoutReplacement(
|
|
||||||
num_slices=num_slices,
|
|
||||||
slice_len=slice_len,
|
|
||||||
strict_length=strict_length,
|
|
||||||
shuffle=self.shuffle,
|
|
||||||
)
|
|
||||||
|
|
||||||
if writer is None:
|
|
||||||
writer = ImmutableDatasetWriter()
|
|
||||||
if collate_fn is None:
|
|
||||||
collate_fn = _collate_id
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
storage=storage,
|
dataset_id,
|
||||||
sampler=sampler,
|
batch_size,
|
||||||
writer=writer,
|
shuffle=shuffle,
|
||||||
collate_fn=collate_fn,
|
root=root,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
prefetch=prefetch,
|
prefetch=prefetch,
|
||||||
batch_size=batch_size,
|
sampler=sampler,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
writer=writer,
|
||||||
transform=transform,
|
transform=transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def num_samples(self):
|
|
||||||
return len(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_episodes(self):
|
|
||||||
return len(self._storage._storage["episode"].unique())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data_path_root(self):
|
|
||||||
if self.streaming:
|
|
||||||
return None
|
|
||||||
return self.root / self.dataset_id
|
|
||||||
|
|
||||||
def _is_downloaded(self):
|
|
||||||
return os.path.exists(self.data_path_root)
|
|
||||||
|
|
||||||
def _download_and_preproc(self):
|
def _download_and_preproc(self):
|
||||||
# download
|
# download
|
||||||
# TODO(rcadene)
|
# TODO(rcadene)
|
||||||
|
|
||||||
# load
|
dataset_path = self.data_dir / "buffer.pkl"
|
||||||
dataset_dir = Path("data") / self.dataset_id
|
|
||||||
dataset_path = dataset_dir / "buffer.pkl"
|
|
||||||
print(f"Using offline dataset '{dataset_path}'")
|
print(f"Using offline dataset '{dataset_path}'")
|
||||||
with open(dataset_path, "rb") as f:
|
with open(dataset_path, "rb") as f:
|
||||||
dataset_dict = pickle.load(f)
|
dataset_dict = pickle.load(f)
|
||||||
|
@ -172,7 +93,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||||
|
|
||||||
if episode_id == 0:
|
if episode_id == 0:
|
||||||
# hack to initialize tensordict data structure to store episodes
|
# hack to initialize tensordict data structure to store episodes
|
||||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
|
||||||
|
|
||||||
td_data[idx0:idx1] = episode
|
td_data[idx0:idx1] = episode
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,10 @@ from omegaconf import OmegaConf
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
|
||||||
|
def log_output_dir(out_dir):
|
||||||
|
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||||
|
|
||||||
|
|
||||||
def cfg_to_group(cfg, return_list=False):
|
def cfg_to_group(cfg, return_list=False):
|
||||||
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
|
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
|
||||||
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
|
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
|
||||||
|
|
|
@ -9,13 +9,13 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from termcolor import colored
|
|
||||||
from torchrl.envs import EnvBase
|
from torchrl.envs import EnvBase
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
|
from lerobot.common.logger import log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import init_logging, set_seed
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path, stacked_frames, fps):
|
def write_video(video_path, stacked_frames, fps):
|
||||||
|
@ -109,10 +109,18 @@ def eval(cfg: dict, out_dir=None):
|
||||||
if out_dir is None:
|
if out_dir is None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
assert torch.cuda.is_available()
|
init_logging()
|
||||||
|
|
||||||
|
if cfg.device == "cuda":
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
else:
|
||||||
|
logging.warning("Using CPU, this will be slow.")
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
|
|
||||||
|
log_output_dir(out_dir)
|
||||||
|
|
||||||
logging.info("make_offline_buffer")
|
logging.info("make_offline_buffer")
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
|
@ -142,6 +150,8 @@ def eval(cfg: dict, out_dir=None):
|
||||||
)
|
)
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
|
logging.info("End of eval")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
eval_cli()
|
eval_cli()
|
||||||
|
|
|
@ -4,13 +4,12 @@ import hydra
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from termcolor import colored
|
|
||||||
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
||||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils import format_big_number, init_logging, set_seed
|
from lerobot.common.utils import format_big_number, init_logging, set_seed
|
||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
|
@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
# log metrics to terminal and wandb
|
# log metrics to terminal and wandb
|
||||||
logger = Logger(out_dir, job_name, cfg)
|
logger = Logger(out_dir, job_name, cfg)
|
||||||
|
|
||||||
logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
log_output_dir(out_dir)
|
||||||
logging.info(f"{cfg.env.task=}")
|
logging.info(f"{cfg.env.task=}")
|
||||||
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
|
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
|
||||||
logging.info(f"{cfg.online_steps=}")
|
logging.info(f"{cfg.online_steps=}")
|
||||||
|
@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
for env_step in range(cfg.online_steps):
|
for env_step in range(cfg.online_steps):
|
||||||
if env_step == 0:
|
if env_step == 0:
|
||||||
logging.info("Start online training by interacting with environment")
|
logging.info("Start online training by interacting with environment")
|
||||||
# TODO: use SyncDataCollector for that?
|
|
||||||
# TODO: add configurable number of rollout? (default=1)
|
# TODO: add configurable number of rollout? (default=1)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
rollout = env.rollout(
|
rollout = env.rollout(
|
||||||
|
@ -268,6 +266,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
step += 1
|
step += 1
|
||||||
online_step += 1
|
online_step += 1
|
||||||
|
|
||||||
|
logging.info("End of training")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_cli()
|
train_cli()
|
||||||
|
|
|
@ -1,13 +1,20 @@
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import einops
|
||||||
import hydra
|
import hydra
|
||||||
import imageio
|
import imageio
|
||||||
import torch
|
import torch
|
||||||
from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
|
from torchrl.data.replay_buffers import (
|
||||||
|
SamplerWithoutReplacement,
|
||||||
|
)
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_offline_buffer
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
|
from lerobot.common.logger import log_output_dir
|
||||||
|
from lerobot.common.utils import init_logging
|
||||||
|
|
||||||
NUM_EPISODES_TO_RENDER = 10
|
NUM_EPISODES_TO_RENDER = 50
|
||||||
MAX_NUM_STEPS = 1000
|
MAX_NUM_STEPS = 1000
|
||||||
FIRST_FRAME = 0
|
FIRST_FRAME = 0
|
||||||
|
|
||||||
|
@ -17,45 +24,88 @@ def visualize_dataset_cli(cfg: dict):
|
||||||
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
|
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def cat_and_write_video(video_path, frames, fps):
|
||||||
|
frames = torch.cat(frames)
|
||||||
|
assert frames.dtype == torch.uint8
|
||||||
|
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
||||||
|
imageio.mimsave(video_path, frames, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
def visualize_dataset(cfg: dict, out_dir=None):
|
def visualize_dataset(cfg: dict, out_dir=None):
|
||||||
if out_dir is None:
|
if out_dir is None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
sampler = SliceSamplerWithoutReplacement(
|
init_logging()
|
||||||
num_slices=1,
|
log_output_dir(out_dir)
|
||||||
strict_length=False,
|
|
||||||
|
# we expect frames of each episode to be stored next to each others sequentially
|
||||||
|
sampler = SamplerWithoutReplacement(
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
offline_buffer = make_offline_buffer(cfg, sampler)
|
logging.info("make_offline_buffer")
|
||||||
|
offline_buffer = make_offline_buffer(
|
||||||
|
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
|
||||||
|
)
|
||||||
|
|
||||||
for _ in range(NUM_EPISODES_TO_RENDER):
|
logging.info("Start rendering episodes from offline buffer")
|
||||||
episode = offline_buffer.sample(MAX_NUM_STEPS)
|
|
||||||
|
|
||||||
ep_idx = episode["episode"][FIRST_FRAME].item()
|
threads = []
|
||||||
ep_frames = torch.cat(
|
frames = {}
|
||||||
[
|
current_ep_idx = 0
|
||||||
episode["observation"]["image"][FIRST_FRAME][None, ...],
|
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||||
episode["next", "observation"]["image"],
|
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
|
||||||
],
|
# TODO(rcadene): make it work with bsize > 1
|
||||||
dim=0,
|
ep_td = offline_buffer.sample(1)
|
||||||
)
|
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||||
|
|
||||||
video_dir = Path(out_dir) / "visualize_dataset"
|
# TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||||
video_dir.mkdir(parents=True, exist_ok=True)
|
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||||
# TODO(rcadene): make fps configurable
|
new_episode = ep_idx != current_ep_idx
|
||||||
video_path = video_dir / f"episode_{ep_idx}.mp4"
|
|
||||||
|
|
||||||
assert ep_frames.min().item() >= 0
|
if new_episode:
|
||||||
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
|
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||||
assert ep_frames.max().item() <= 255
|
|
||||||
ep_frames = ep_frames.type(torch.uint8)
|
|
||||||
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
|
|
||||||
|
|
||||||
# ran out of episodes
|
for im_key in offline_buffer.image_keys:
|
||||||
if offline_buffer._sampler._sample_list.numel() == 0:
|
if new_episode or no_more_frames:
|
||||||
|
# append last observed frames (the ones after last action taken)
|
||||||
|
frames[im_key].append(ep_td[("next", *im_key)])
|
||||||
|
|
||||||
|
video_dir = Path(out_dir) / "visualize_dataset"
|
||||||
|
video_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if len(offline_buffer.image_keys) > 1:
|
||||||
|
camera = im_key[-1]
|
||||||
|
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||||
|
else:
|
||||||
|
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
|
||||||
|
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=cat_and_write_video,
|
||||||
|
args=(str(video_path), frames[im_key], cfg.fps),
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
threads.append(thread)
|
||||||
|
|
||||||
|
current_ep_idx = ep_idx
|
||||||
|
|
||||||
|
# reset list of frames
|
||||||
|
del frames[im_key]
|
||||||
|
|
||||||
|
# append current cameras images to list of frames
|
||||||
|
if im_key not in frames:
|
||||||
|
frames[im_key] = []
|
||||||
|
frames[im_key].append(ep_td[im_key])
|
||||||
|
|
||||||
|
if no_more_frames:
|
||||||
|
logging.info("Ran out of frames")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
logging.info("End of visualize_dataset")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
visualize_dataset_cli()
|
visualize_dataset_cli()
|
||||||
|
|
Loading…
Reference in New Issue