Refactor datasets with abstract class
This commit is contained in:
parent
e132a267aa
commit
d4e0849970
|
@ -0,0 +1,185 @@
|
|||
import abc
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
|
||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
||||
):
|
||||
self.dataset_id = dataset_id
|
||||
self.shuffle = shuffle
|
||||
self.root = _get_root_dir(self.dataset_id) if root is None else root
|
||||
self.root = Path(self.root)
|
||||
self.data_dir = self.root / self.dataset_id
|
||||
|
||||
storage = self._download_or_load_storage()
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=ImmutableDatasetWriter() if writer is None else writer,
|
||||
collate_fn=_collate_id if collate_fn is None else collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
def set_transform(self, transform):
|
||||
self.transform = transform
|
||||
|
||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||
stats_path = self.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(self._storage._storage, num_batch, batch_size)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@abc.abstractmethod
|
||||
def _download_and_preproc(self) -> torch.StorageBase:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _download_or_load_storage(self):
|
||||
if not self._is_downloaded():
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
|
||||
return storage
|
||||
|
||||
def _is_downloaded(self) -> bool:
|
||||
return self.data_dir.is_dir()
|
||||
|
||||
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
batch = rb.sample()
|
||||
|
||||
image_channels = batch["observation", "image"].shape[1]
|
||||
image_mean = torch.zeros(image_channels)
|
||||
image_std = torch.zeros(image_channels)
|
||||
image_max = torch.tensor([-math.inf] * image_channels)
|
||||
image_min = torch.tensor([math.inf] * image_channels)
|
||||
|
||||
state_channels = batch["observation", "state"].shape[1]
|
||||
state_mean = torch.zeros(state_channels)
|
||||
state_std = torch.zeros(state_channels)
|
||||
state_max = torch.tensor([-math.inf] * state_channels)
|
||||
state_min = torch.tensor([math.inf] * state_channels)
|
||||
|
||||
action_channels = batch["action"].shape[1]
|
||||
action_mean = torch.zeros(action_channels)
|
||||
action_std = torch.zeros(action_channels)
|
||||
action_max = torch.tensor([-math.inf] * action_channels)
|
||||
action_min = torch.tensor([math.inf] * action_channels)
|
||||
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
batch = rb.sample()
|
||||
|
||||
image_mean /= num_batch
|
||||
state_mean /= num_batch
|
||||
action_mean /= num_batch
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
image_std += (b_image_mean - image_mean) ** 2
|
||||
state_std += (b_state_mean - state_mean) ** 2
|
||||
action_std += (b_action_mean - action_mean) ** 2
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
if i < num_batch - 1:
|
||||
batch = rb.sample()
|
||||
|
||||
image_std = torch.sqrt(image_std / num_batch)
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
|
||||
stats = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||
("observation", "image", "std"): image_std[None, :, None, None],
|
||||
("observation", "image", "max"): image_max[None, :, None, None],
|
||||
("observation", "image", "min"): image_min[None, :, None, None],
|
||||
("observation", "state", "mean"): state_mean[None, :],
|
||||
("observation", "state", "std"): state_std[None, :],
|
||||
("observation", "state", "max"): state_max[None, :],
|
||||
("observation", "state", "min"): state_min[None, :],
|
||||
("action", "mean"): action_mean[None, :],
|
||||
("action", "std"): action_std[None, :],
|
||||
("action", "max"): action_max[None, :],
|
||||
("action", "min"): action_min[None, :],
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||
return stats
|
|
@ -5,31 +5,10 @@ from pathlib import Path
|
|||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
from lerobot.common.envs.transforms import NormalizeTransform
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
|
||||
# TODO(rcadene): implement
|
||||
|
||||
# dataset_d4rl = D4RLExperienceReplay(
|
||||
# dataset_id="maze2d-umaze-v1",
|
||||
# split_trajs=False,
|
||||
# batch_size=1,
|
||||
# sampler=SamplerWithoutReplacement(drop_last=False),
|
||||
# prefetch=4,
|
||||
# direct_download=True,
|
||||
# )
|
||||
|
||||
# dataset_openx = OpenXExperienceReplay(
|
||||
# "cmu_stretch",
|
||||
# batch_size=1,
|
||||
# num_slices=1,
|
||||
# #download="force",
|
||||
# streaming=False,
|
||||
# root="data",
|
||||
# )
|
||||
|
||||
|
||||
def make_offline_buffer(cfg, sampler=None):
|
||||
if cfg.policy.balanced_sampling:
|
||||
|
@ -69,31 +48,49 @@ def make_offline_buffer(cfg, sampler=None):
|
|||
)
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
|
||||
offline_buffer = SimxarmExperienceReplay(
|
||||
f"xarm_{cfg.env.task}_medium",
|
||||
# download="force",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root=str(DATA_DIR),
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
|
||||
clsfunc = SimxarmExperienceReplay
|
||||
dataset_id = f"xarm_{cfg.env.task}_medium"
|
||||
|
||||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
"pusht",
|
||||
streaming=False,
|
||||
root=DATA_DIR,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
|
||||
clsfunc = PushtExperienceReplay
|
||||
dataset_id = "pusht"
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=dataset_id,
|
||||
root=DATA_DIR,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy == "tdmpc":
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||
in_keys.append(("observation", "image"))
|
||||
# since we use next observations in tdmpc
|
||||
in_keys.append(("next", "observation", "image"))
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||
offline_buffer.set_transform(transform)
|
||||
|
||||
if not overwrite_sampler:
|
||||
num_steps = len(offline_buffer)
|
||||
index = torch.arange(0, num_steps, 1)
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
|
@ -12,16 +9,14 @@ import torch
|
|||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.envs.transforms import NormalizeTransform
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
@ -87,114 +82,36 @@ def add_tee(
|
|||
return body
|
||||
|
||||
|
||||
class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
pad: float = None,
|
||||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa: F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
||||
):
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
self.dataset_id = dataset_id
|
||||
self.split_trajs = split_trajs
|
||||
self.shuffle = shuffle
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.pad = pad
|
||||
|
||||
self.strict_length = strict_length
|
||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("pusht")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
self.root = root
|
||||
if not self._is_downloaded():
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
stats = self._compute_or_load_stats(storage)
|
||||
transform = NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy
|
||||
# We need to automate this for tdmpc and others
|
||||
# ("observation", "image"),
|
||||
("observation", "state"),
|
||||
# TODO(rcadene): for tdmpc, we might want next image and state
|
||||
# ("next", "observation", "image"),
|
||||
# ("next", "observation", "state"),
|
||||
("action"),
|
||||
],
|
||||
mode="min_max",
|
||||
)
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
transform.stats["observation", "state", "min"] = torch.tensor(
|
||||
[13.456424, 32.938293], dtype=torch.float32
|
||||
)
|
||||
transform.stats["observation", "state", "max"] = torch.tensor(
|
||||
[496.14618, 510.9579], dtype=torch.float32
|
||||
)
|
||||
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
collate_fn = _collate_id
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=writer,
|
||||
collate_fn=collate_fn,
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def data_path_root(self) -> Path:
|
||||
return None if self.streaming else self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self) -> bool:
|
||||
return self.data_path_root.is_dir()
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
raw_dir = self.root / "raw"
|
||||
raw_dir = self.data_dir / "raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -286,7 +203,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
|
||||
td_data[idxtd : idxtd + len(episode)] = episode
|
||||
|
||||
|
@ -294,112 +211,3 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
|||
idxtd = idxtd + len(episode)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
|
||||
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
batch = rb.sample()
|
||||
|
||||
image_channels = batch["observation", "image"].shape[1]
|
||||
image_mean = torch.zeros(image_channels)
|
||||
image_std = torch.zeros(image_channels)
|
||||
image_max = torch.tensor([-math.inf] * image_channels)
|
||||
image_min = torch.tensor([math.inf] * image_channels)
|
||||
|
||||
state_channels = batch["observation", "state"].shape[1]
|
||||
state_mean = torch.zeros(state_channels)
|
||||
state_std = torch.zeros(state_channels)
|
||||
state_max = torch.tensor([-math.inf] * state_channels)
|
||||
state_min = torch.tensor([math.inf] * state_channels)
|
||||
|
||||
action_channels = batch["action"].shape[1]
|
||||
action_mean = torch.zeros(action_channels)
|
||||
action_std = torch.zeros(action_channels)
|
||||
action_max = torch.tensor([-math.inf] * action_channels)
|
||||
action_min = torch.tensor([math.inf] * action_channels)
|
||||
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
batch = rb.sample()
|
||||
|
||||
image_mean /= num_batch
|
||||
state_mean /= num_batch
|
||||
action_mean /= num_batch
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
image_std += (b_image_mean - image_mean) ** 2
|
||||
state_std += (b_state_mean - state_mean) ** 2
|
||||
action_std += (b_action_mean - action_mean) ** 2
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
if i < num_batch - 1:
|
||||
batch = rb.sample()
|
||||
|
||||
image_std = torch.sqrt(image_std / num_batch)
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
|
||||
stats = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||
("observation", "image", "std"): image_std[None, :, None, None],
|
||||
("observation", "image", "max"): image_max[None, :, None, None],
|
||||
("observation", "image", "min"): image_min[None, :, None, None],
|
||||
("observation", "state", "mean"): state_mean[None, :],
|
||||
("observation", "state", "std"): state_std[None, :],
|
||||
("observation", "state", "max"): state_max[None, :],
|
||||
("observation", "state", "min"): state_min[None, :],
|
||||
("action", "mean"): action_mean[None, :],
|
||||
("action", "std"): action_std[None, :],
|
||||
("action", "max"): action_max[None, :],
|
||||
("action", "min"): action_min[None, :],
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||
return stats
|
||||
|
||||
def _compute_or_load_stats(self, storage) -> TensorDict:
|
||||
stats_path = self.root / self.dataset_id / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(storage)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
@ -7,130 +6,52 @@ import torch
|
|||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
SliceSampler,
|
||||
SliceSamplerWithoutReplacement,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
pad: float = None,
|
||||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
download: bool = False,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
self.download = download
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
self.dataset_id = dataset_id
|
||||
self.split_trajs = split_trajs
|
||||
self.shuffle = shuffle
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.pad = pad
|
||||
|
||||
self.strict_length = strict_length
|
||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("simxarm")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
self.root = Path(root)
|
||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
if num_slices is not None or slice_len is not None:
|
||||
if sampler is not None:
|
||||
raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
|
||||
|
||||
if replacement:
|
||||
if not self.shuffle:
|
||||
raise RuntimeError("shuffle=False can only be used when replacement=False.")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
)
|
||||
else:
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
shuffle=self.shuffle,
|
||||
)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
collate_fn = _collate_id
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=writer,
|
||||
collate_fn=collate_fn,
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self):
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def data_path_root(self):
|
||||
if self.streaming:
|
||||
return None
|
||||
return self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self):
|
||||
return os.path.exists(self.data_path_root)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
# TODO(rcadene)
|
||||
|
||||
# load
|
||||
dataset_dir = Path("data") / self.dataset_id
|
||||
dataset_path = dataset_dir / "buffer.pkl"
|
||||
dataset_path = self.data_dir / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
|
@ -172,7 +93,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
|||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
|
|
Loading…
Reference in New Issue