Merge pull request #7 from Cadene/user/rcadene/2024_03_05_abstract_replay_buffer

Add AbstractReplayBuffer
This commit is contained in:
Remi 2024-03-06 11:25:24 +01:00 committed by GitHub
commit 49c0955f97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 351 additions and 395 deletions

View File

@ -0,0 +1,158 @@
import abc
import logging
from pathlib import Path
from typing import Callable
import einops
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id: str,
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
):
self.dataset_id = dataset_id
self.shuffle = shuffle
self.root = _get_root_dir(self.dataset_id) if root is None else root
self.root = Path(self.root)
self.data_dir = self.root / self.dataset_id
storage = self._download_or_load_storage()
super().__init__(
storage=storage,
sampler=sampler,
writer=ImmutableDatasetWriter() if writer is None else writer,
collate_fn=_collate_id if collate_fn is None else collate_fn,
pin_memory=pin_memory,
prefetch=prefetch,
batch_size=batch_size,
transform=transform,
)
@property
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action"): "b c -> 1 c",
}
@property
def image_keys(self) -> list:
return [("observation", "image")]
@property
def num_cameras(self) -> int:
return len(self.image_keys)
@property
def num_samples(self) -> int:
return len(self)
@property
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
def set_transform(self, transform):
self.transform = transform
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
stats_path = self.data_dir / "stats.pth"
if stats_path.exists():
stats = torch.load(stats_path)
else:
logging.info(f"compute_stats and save to {stats_path}")
stats = self._compute_stats(num_batch, batch_size)
torch.save(stats, stats_path)
return stats
@abc.abstractmethod
def _download_and_preproc(self) -> torch.StorageBase:
raise NotImplementedError()
def _download_or_load_storage(self):
if not self._is_downloaded():
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
return storage
def _is_downloaded(self) -> bool:
return self.data_dir.is_dir()
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(
storage=self._storage,
batch_size=batch_size,
prefetch=True,
)
mean, std, max, min = {}, {}, {}, {}
# compute mean, min, max
for _ in tqdm.tqdm(range(num_batch)):
batch = rb.sample()
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
if key not in mean:
# first batch initialize mean, min, max
mean[key] = einops.reduce(batch[key], pattern, "mean")
max[key] = einops.reduce(batch[key], pattern, "max")
min[key] = einops.reduce(batch[key], pattern, "min")
else:
mean[key] += einops.reduce(batch[key], pattern, "mean")
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
batch = rb.sample()
for key in self.stats_patterns:
mean[key] /= num_batch
# compute std, min, max
for _ in tqdm.tqdm(range(num_batch)):
batch = rb.sample()
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
batch_mean = einops.reduce(batch[key], pattern, "mean")
if key not in std:
# first batch initialize std
std[key] = (batch_mean - mean[key]) ** 2
else:
std[key] += (batch_mean - mean[key]) ** 2
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
for key in self.stats_patterns:
std[key] = torch.sqrt(std[key] / num_batch)
stats = TensorDict({}, batch_size=[])
for key in self.stats_patterns:
stats[(*key, "mean")] = mean[key]
stats[(*key, "std")] = std[key]
stats[(*key, "max")] = max[key]
stats[(*key, "min")] = min[key]
if key[0] == "observation":
# use same stats for the next observations
stats[("next", *key)] = stats[key]
return stats

View File

@ -5,33 +5,14 @@ from pathlib import Path
import torch import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.datasets.pusht import PushtExperienceReplay from lerobot.common.envs.transforms import NormalizeTransform
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
# TODO(rcadene): implement
# dataset_d4rl = D4RLExperienceReplay( def make_offline_buffer(
# dataset_id="maze2d-umaze-v1", cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
# split_trajs=False, ):
# batch_size=1,
# sampler=SamplerWithoutReplacement(drop_last=False),
# prefetch=4,
# direct_download=True,
# )
# dataset_openx = OpenXExperienceReplay(
# "cmu_stretch",
# batch_size=1,
# num_slices=1,
# #download="force",
# streaming=False,
# root="data",
# )
def make_offline_buffer(cfg, sampler=None):
if cfg.policy.balanced_sampling: if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0 assert cfg.online_steps > 0
batch_size = None batch_size = None
@ -44,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None):
pin_memory = cfg.device == "cuda" pin_memory = cfg.device == "cuda"
prefetch = cfg.prefetch prefetch = cfg.prefetch
overwrite_sampler = sampler is not None if overwrite_batch_size is not None:
batch_size = overwrite_batch_size
if not overwrite_sampler: if overwrite_prefetch is not None:
prefetch = overwrite_prefetch
if overwrite_sampler is None:
# TODO(rcadene): move batch_size outside # TODO(rcadene): move batch_size outside
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
@ -67,36 +52,57 @@ def make_offline_buffer(cfg, sampler=None):
num_slices=num_traj_per_batch, num_slices=num_traj_per_batch,
strict_length=False, strict_length=False,
) )
else:
sampler = overwrite_sampler
if cfg.env.name == "simxarm": if cfg.env.name == "simxarm":
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
offline_buffer = SimxarmExperienceReplay(
f"xarm_{cfg.env.task}_medium", clsfunc = SimxarmExperienceReplay
# download="force", dataset_id = f"xarm_{cfg.env.task}_medium"
download=True,
streaming=False,
root=str(DATA_DIR),
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
elif cfg.env.name == "pusht": elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay( from lerobot.common.datasets.pusht import PushtExperienceReplay
"pusht",
streaming=False, clsfunc = PushtExperienceReplay
root=DATA_DIR, dataset_id = "pusht"
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
else: else:
raise ValueError(cfg.env.name) raise ValueError(cfg.env.name)
offline_buffer = clsfunc(
dataset_id=dataset_id,
root=DATA_DIR,
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
in_keys = [("observation", "state"), ("action")]
if cfg.policy == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
# since we use next observations in tdmpc
in_keys.append(("next", *key))
in_keys.append(("next", "observation", "state"))
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
transform = NormalizeTransform(stats, in_keys, mode="min_max")
offline_buffer.set_transform(transform)
if not overwrite_sampler: if not overwrite_sampler:
num_steps = len(offline_buffer) index = torch.arange(0, offline_buffer.num_samples, 1)
index = torch.arange(0, num_steps, 1)
sampler.extend(index) sampler.extend(index)
return offline_buffer return offline_buffer

View File

@ -1,6 +1,3 @@
import logging
import math
import os
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable
@ -12,16 +9,14 @@ import torch
import torchrl import torchrl
import tqdm import tqdm
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.writers import Writer
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.utils import download_and_extract_zip from lerobot.common.datasets.utils import download_and_extract_zip
from lerobot.common.envs.transforms import NormalizeTransform
# as define in env # as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage, SUCCESS_THRESHOLD = 0.95 # 95% coverage,
@ -87,114 +82,36 @@ def add_tee(
return body return body
class PushtExperienceReplay(TensorDictReplayBuffer): class PushtExperienceReplay(AbstractExperienceReplay):
def __init__( def __init__(
self, self,
dataset_id: str, dataset_id: str,
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,
num_slices: int = None,
slice_len: int = None,
pad: float = None,
replacement: bool = None,
streaming: bool = False,
root: Path = None, root: Path = None,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
pin_memory: bool = False, pin_memory: bool = False,
prefetch: int = None, prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa: F821 sampler: SliceSampler = None,
split_trajs: bool = False, collate_fn: Callable = None,
strict_length: bool = True, writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
): ):
if streaming:
raise NotImplementedError
self.streaming = streaming
self.dataset_id = dataset_id
self.split_trajs = split_trajs
self.shuffle = shuffle
self.num_slices = num_slices
self.slice_len = slice_len
self.pad = pad
self.strict_length = strict_length
if (self.num_slices is not None) and (self.slice_len is not None):
raise ValueError("num_slices or slice_len can be not None, but not both.")
if split_trajs:
raise NotImplementedError
if root is None:
root = _get_root_dir("pusht")
os.makedirs(root, exist_ok=True)
self.root = root
if not self._is_downloaded():
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
stats = self._compute_or_load_stats(storage)
transform = NormalizeTransform(
stats,
in_keys=[
# TODO(rcadene): imagenet normalization is applied inside diffusion policy
# We need to automate this for tdmpc and others
# ("observation", "image"),
("observation", "state"),
# TODO(rcadene): for tdmpc, we might want next image and state
# ("next", "observation", "image"),
# ("next", "observation", "state"),
("action"),
],
mode="min_max",
)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
transform.stats["observation", "state", "min"] = torch.tensor(
[13.456424, 32.938293], dtype=torch.float32
)
transform.stats["observation", "state", "max"] = torch.tensor(
[496.14618, 510.9579], dtype=torch.float32
)
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None:
collate_fn = _collate_id
super().__init__( super().__init__(
storage=storage, dataset_id,
sampler=sampler, batch_size,
writer=writer, shuffle=shuffle,
collate_fn=collate_fn, root=root,
pin_memory=pin_memory, pin_memory=pin_memory,
prefetch=prefetch, prefetch=prefetch,
batch_size=batch_size, sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform, transform=transform,
) )
@property
def num_samples(self) -> int:
return len(self)
@property
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
@property
def data_path_root(self) -> Path:
return None if self.streaming else self.root / self.dataset_id
def _is_downloaded(self) -> bool:
return self.data_path_root.is_dir()
def _download_and_preproc(self): def _download_and_preproc(self):
# download raw_dir = self.data_dir / "raw"
raw_dir = self.root / "raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve() zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir(): if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True) raw_dir.mkdir(parents=True, exist_ok=True)
@ -266,8 +183,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
# last step of demonstration is considered done # last step of demonstration is considered done
done[-1] = True done[-1] = True
print("before " + """episode = TensorDict(""") ep_td = TensorDict(
episode = TensorDict(
{ {
("observation", "image"): image[:-1], ("observation", "image"): image[:-1],
("observation", "state"): agent_pos[:-1], ("observation", "state"): agent_pos[:-1],
@ -286,120 +202,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
if episode_id == 0: if episode_id == 0:
# hack to initialize tensordict data structure to store episodes # hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id) td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
td_data[idxtd : idxtd + len(episode)] = episode td_data[idxtd : idxtd + len(ep_td)] = ep_td
idx0 = idx1 idx0 = idx1
idxtd = idxtd + len(episode) idxtd = idxtd + len(ep_td)
return TensorStorage(td_data.lock_()) return TensorStorage(td_data.lock_())
def _compute_stats(self, storage, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(
storage=storage,
batch_size=batch_size,
prefetch=True,
)
batch = rb.sample()
image_channels = batch["observation", "image"].shape[1]
image_mean = torch.zeros(image_channels)
image_std = torch.zeros(image_channels)
image_max = torch.tensor([-math.inf] * image_channels)
image_min = torch.tensor([math.inf] * image_channels)
state_channels = batch["observation", "state"].shape[1]
state_mean = torch.zeros(state_channels)
state_std = torch.zeros(state_channels)
state_max = torch.tensor([-math.inf] * state_channels)
state_min = torch.tensor([math.inf] * state_channels)
action_channels = batch["action"].shape[1]
action_mean = torch.zeros(action_channels)
action_std = torch.zeros(action_channels)
action_max = torch.tensor([-math.inf] * action_channels)
action_min = torch.tensor([math.inf] * action_channels)
for _ in tqdm.tqdm(range(num_batch)):
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
image_max = torch.maximum(image_max, b_image_max)
image_min = torch.maximum(image_min, b_image_min)
state_max = torch.maximum(state_max, b_state_max)
state_min = torch.maximum(state_min, b_state_min)
action_max = torch.maximum(action_max, b_action_max)
action_min = torch.maximum(action_min, b_action_min)
batch = rb.sample()
image_mean /= num_batch
state_mean /= num_batch
action_mean /= num_batch
for i in tqdm.tqdm(range(num_batch)):
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
image_std += (b_image_mean - image_mean) ** 2
state_std += (b_state_mean - state_mean) ** 2
action_std += (b_action_mean - action_mean) ** 2
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
image_max = torch.maximum(image_max, b_image_max)
image_min = torch.maximum(image_min, b_image_min)
state_max = torch.maximum(state_max, b_state_max)
state_min = torch.maximum(state_min, b_state_min)
action_max = torch.maximum(action_max, b_action_max)
action_min = torch.maximum(action_min, b_action_min)
if i < num_batch - 1:
batch = rb.sample()
image_std = torch.sqrt(image_std / num_batch)
state_std = torch.sqrt(state_std / num_batch)
action_std = torch.sqrt(action_std / num_batch)
stats = TensorDict(
{
("observation", "image", "mean"): image_mean[None, :, None, None],
("observation", "image", "std"): image_std[None, :, None, None],
("observation", "image", "max"): image_max[None, :, None, None],
("observation", "image", "min"): image_min[None, :, None, None],
("observation", "state", "mean"): state_mean[None, :],
("observation", "state", "std"): state_std[None, :],
("observation", "state", "max"): state_max[None, :],
("observation", "state", "min"): state_min[None, :],
("action", "mean"): action_mean[None, :],
("action", "std"): action_std[None, :],
("action", "max"): action_max[None, :],
("action", "min"): action_min[None, :],
},
batch_size=[],
)
stats["next", "observation", "image"] = stats["observation", "image"]
stats["next", "observation", "state"] = stats["observation", "state"]
return stats
def _compute_or_load_stats(self, storage) -> TensorDict:
stats_path = self.root / self.dataset_id / "stats.pth"
if stats_path.exists():
stats = torch.load(stats_path)
else:
logging.info(f"compute_stats and save to {stats_path}")
stats = self._compute_stats(storage)
torch.save(stats, stats_path)
return stats

View File

@ -1,4 +1,3 @@
import os
import pickle import pickle
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable
@ -7,130 +6,52 @@ import torch
import torchrl import torchrl
import tqdm import tqdm
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import ( from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler, SliceSampler,
SliceSamplerWithoutReplacement,
) )
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
class SimxarmExperienceReplay(TensorDictReplayBuffer): class SimxarmExperienceReplay(AbstractExperienceReplay):
available_datasets = [ available_datasets = [
"xarm_lift_medium", "xarm_lift_medium",
] ]
def __init__( def __init__(
self, self,
dataset_id, dataset_id: str,
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,
num_slices: int = None,
slice_len: int = None,
pad: float = None,
replacement: bool = None,
streaming: bool = False,
root: Path = None, root: Path = None,
download: bool = False,
sampler: Sampler = None,
writer: Writer = None,
collate_fn: Callable = None,
pin_memory: bool = False, pin_memory: bool = False,
prefetch: int = None, prefetch: int = None,
transform: "torchrl.envs.Transform" = None, # noqa-F821 sampler: SliceSampler = None,
split_trajs: bool = False, collate_fn: Callable = None,
strict_length: bool = True, writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
): ):
self.download = download
if streaming:
raise NotImplementedError
self.streaming = streaming
self.dataset_id = dataset_id
self.split_trajs = split_trajs
self.shuffle = shuffle
self.num_slices = num_slices
self.slice_len = slice_len
self.pad = pad
self.strict_length = strict_length
if (self.num_slices is not None) and (self.slice_len is not None):
raise ValueError("num_slices or slice_len can be not None, but not both.")
if split_trajs:
raise NotImplementedError
if root is None:
root = _get_root_dir("simxarm")
os.makedirs(root, exist_ok=True)
self.root = Path(root)
if self.download == "force" or (self.download and not self._is_downloaded()):
storage = self._download_and_preproc()
else:
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
if num_slices is not None or slice_len is not None:
if sampler is not None:
raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
if replacement:
if not self.shuffle:
raise RuntimeError("shuffle=False can only be used when replacement=False.")
sampler = SliceSampler(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
)
else:
sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices,
slice_len=slice_len,
strict_length=strict_length,
shuffle=self.shuffle,
)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None:
collate_fn = _collate_id
super().__init__( super().__init__(
storage=storage, dataset_id,
sampler=sampler, batch_size,
writer=writer, shuffle=shuffle,
collate_fn=collate_fn, root=root,
pin_memory=pin_memory, pin_memory=pin_memory,
prefetch=prefetch, prefetch=prefetch,
batch_size=batch_size, sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform, transform=transform,
) )
@property
def num_samples(self):
return len(self)
@property
def num_episodes(self):
return len(self._storage._storage["episode"].unique())
@property
def data_path_root(self):
if self.streaming:
return None
return self.root / self.dataset_id
def _is_downloaded(self):
return os.path.exists(self.data_path_root)
def _download_and_preproc(self): def _download_and_preproc(self):
# download # download
# TODO(rcadene) # TODO(rcadene)
# load dataset_path = self.data_dir / "buffer.pkl"
dataset_dir = Path("data") / self.dataset_id
dataset_path = dataset_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'") print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f: with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f) dataset_dict = pickle.load(f)
@ -172,7 +93,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
if episode_id == 0: if episode_id == 0:
# hack to initialize tensordict data structure to store episodes # hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id) td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
td_data[idx0:idx1] = episode td_data[idx0:idx1] = episode

View File

@ -6,6 +6,10 @@ from omegaconf import OmegaConf
from termcolor import colored from termcolor import colored
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg, return_list=False): def cfg_to_group(cfg, return_list=False):
"""Return a wandb-safe group name for logging. Optionally returns group name as list.""" """Return a wandb-safe group name for logging. Optionally returns group name as list."""
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]

View File

@ -9,13 +9,13 @@ import numpy as np
import torch import torch
import tqdm import tqdm
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.envs import EnvBase from torchrl.envs import EnvBase
from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import set_seed from lerobot.common.utils import init_logging, set_seed
def write_video(video_path, stacked_frames, fps): def write_video(video_path, stacked_frames, fps):
@ -109,10 +109,18 @@ def eval(cfg: dict, out_dir=None):
if out_dir is None: if out_dir is None:
raise NotImplementedError() raise NotImplementedError()
assert torch.cuda.is_available() init_logging()
if cfg.device == "cuda":
assert torch.cuda.is_available()
else:
logging.warning("Using CPU, this will be slow.")
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed) set_seed(cfg.seed)
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
log_output_dir(out_dir)
logging.info("make_offline_buffer") logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg) offline_buffer = make_offline_buffer(cfg)
@ -142,6 +150,8 @@ def eval(cfg: dict, out_dir=None):
) )
print(metrics) print(metrics)
logging.info("End of eval")
if __name__ == "__main__": if __name__ == "__main__":
eval_cli() eval_cli()

View File

@ -4,13 +4,12 @@ import hydra
import numpy as np import numpy as np
import torch import torch
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import format_big_number, init_logging, set_seed from lerobot.common.utils import format_big_number, init_logging, set_seed
from lerobot.scripts.eval import eval_policy from lerobot.scripts.eval import eval_policy
@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# log metrics to terminal and wandb # log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg) logger = Logger(out_dir, job_name, cfg)
logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}") log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}") logging.info(f"{cfg.online_steps=}")
@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
for env_step in range(cfg.online_steps): for env_step in range(cfg.online_steps):
if env_step == 0: if env_step == 0:
logging.info("Start online training by interacting with environment") logging.info("Start online training by interacting with environment")
# TODO: use SyncDataCollector for that?
# TODO: add configurable number of rollout? (default=1) # TODO: add configurable number of rollout? (default=1)
with torch.no_grad(): with torch.no_grad():
rollout = env.rollout( rollout = env.rollout(
@ -268,6 +266,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1 step += 1
online_step += 1 online_step += 1
logging.info("End of training")
if __name__ == "__main__": if __name__ == "__main__":
train_cli() train_cli()

View File

@ -1,13 +1,20 @@
import logging
import threading
from pathlib import Path from pathlib import Path
import einops
import hydra import hydra
import imageio import imageio
import torch import torch
from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.logger import log_output_dir
from lerobot.common.utils import init_logging
NUM_EPISODES_TO_RENDER = 10 NUM_EPISODES_TO_RENDER = 50
MAX_NUM_STEPS = 1000 MAX_NUM_STEPS = 1000
FIRST_FRAME = 0 FIRST_FRAME = 0
@ -17,45 +24,88 @@ def visualize_dataset_cli(cfg: dict):
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
def cat_and_write_video(video_path, frames, fps):
frames = torch.cat(frames)
assert frames.dtype == torch.uint8
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
imageio.mimsave(video_path, frames, fps=fps)
def visualize_dataset(cfg: dict, out_dir=None): def visualize_dataset(cfg: dict, out_dir=None):
if out_dir is None: if out_dir is None:
raise NotImplementedError() raise NotImplementedError()
sampler = SliceSamplerWithoutReplacement( init_logging()
num_slices=1, log_output_dir(out_dir)
strict_length=False,
# we expect frames of each episode to be stored next to each others sequentially
sampler = SamplerWithoutReplacement(
shuffle=False, shuffle=False,
) )
offline_buffer = make_offline_buffer(cfg, sampler) logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
)
for _ in range(NUM_EPISODES_TO_RENDER): logging.info("Start rendering episodes from offline buffer")
episode = offline_buffer.sample(MAX_NUM_STEPS)
ep_idx = episode["episode"][FIRST_FRAME].item() threads = []
ep_frames = torch.cat( frames = {}
[ current_ep_idx = 0
episode["observation"]["image"][FIRST_FRAME][None, ...], logging.info(f"Visualizing episode {current_ep_idx}")
episode["next", "observation"]["image"], for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
], # TODO(rcadene): make it work with bsize > 1
dim=0, ep_td = offline_buffer.sample(1)
) ep_idx = ep_td["episode"][FIRST_FRAME].item()
video_dir = Path(out_dir) / "visualize_dataset" # TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
video_dir.mkdir(parents=True, exist_ok=True) no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
# TODO(rcadene): make fps configurable new_episode = ep_idx != current_ep_idx
video_path = video_dir / f"episode_{ep_idx}.mp4"
assert ep_frames.min().item() >= 0 if new_episode:
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check" logging.info(f"Visualizing episode {current_ep_idx}")
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
# ran out of episodes for im_key in offline_buffer.image_keys:
if offline_buffer._sampler._sample_list.numel() == 0: if new_episode or no_more_frames:
# append last observed frames (the ones after last action taken)
frames[im_key].append(ep_td[("next", *im_key)])
video_dir = Path(out_dir) / "visualize_dataset"
video_dir.mkdir(parents=True, exist_ok=True)
if len(offline_buffer.image_keys) > 1:
camera = im_key[-1]
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else:
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], cfg.fps),
)
thread.start()
threads.append(thread)
current_ep_idx = ep_idx
# reset list of frames
del frames[im_key]
# append current cameras images to list of frames
if im_key not in frames:
frames[im_key] = []
frames[im_key].append(ep_td[im_key])
if no_more_frames:
logging.info("Ran out of frames")
break break
for thread in threads:
thread.join()
logging.info("End of visualize_dataset")
if __name__ == "__main__": if __name__ == "__main__":
visualize_dataset_cli() visualize_dataset_cli()