import logging
from pathlib import Path
from typing import Callable

import einops
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose


class AbstractExperienceReplay(TensorDictReplayBuffer):
    def __init__(
        self,
        dataset_id: str,
        batch_size: int | None = None,
        *,
        shuffle: bool = True,
        root: Path | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        sampler: SliceSampler | None = None,
        collate_fn: Callable | None = None,
        writer: Writer | None = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        self.dataset_id = dataset_id
        self.shuffle = shuffle
        self.root = root
        storage = self._download_or_load_dataset()

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=ImmutableDatasetWriter() if writer is None else writer,
            collate_fn=_collate_id if collate_fn is None else collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @property
    def stats_patterns(self) -> dict:
        # maps each nested TensorDict key to the einops pattern used to reduce it
        # over the batch dimension when computing statistics
        return {
            ("observation", "state"): "b c -> c",
            ("observation", "image"): "b c h w -> c",
            ("action",): "b c -> c",
        }

    @property
    def image_keys(self) -> list:
        return [("observation", "image")]

    @property
    def num_cameras(self) -> int:
        return len(self.image_keys)

    @property
    def num_samples(self) -> int:
        return len(self)

    @property
    def num_episodes(self) -> int:
        return len(self._storage._storage["episode"].unique())

    @property
    def transform(self):
        return self._transform

    def set_transform(self, transform):
        if not isinstance(transform, Compose):
            # required since torchrl calls `len(self._transform)` downstream
            if isinstance(transform, list):
                self._transform = Compose(*transform)
            else:
                self._transform = Compose(transform)
        else:
            self._transform = transform

    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
        stats_path = self.data_dir / "stats.pth"
        if stats_path.exists():
            stats = torch.load(stats_path)
        else:
            logging.info(f"compute_stats and save to {stats_path}")
            stats = self._compute_stats(num_batch, batch_size)
            torch.save(stats, stats_path)
        return stats

    def _download_or_load_dataset(self) -> TensorStorage:
        if self.root is None:
            self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset"))
        else:
            self.data_dir = self.root / self.dataset_id
        return TensorStorage(TensorDict.load_memmap(self.data_dir))

    def _compute_stats(self, num_batch=100, batch_size=32):
        rb = TensorDictReplayBuffer(
            storage=self._storage,
            batch_size=batch_size,
            prefetch=True,
        )
        mean, std, max, min = {}, {}, {}, {}

        # first pass: compute mean, min, max
        for _ in tqdm.tqdm(range(num_batch)):
            batch = rb.sample()
            for key, pattern in self.stats_patterns.items():
                batch[key] = batch[key].float()
                if key not in mean:
                    # first batch initializes mean, min, max
                    mean[key] = einops.reduce(batch[key], pattern, "mean")
                    max[key] = einops.reduce(batch[key], pattern, "max")
                    min[key] = einops.reduce(batch[key], pattern, "min")
                else:
                    mean[key] += einops.reduce(batch[key], pattern, "mean")
                    max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                    min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

        for key in self.stats_patterns:
            mean[key] /= num_batch
        # second pass: accumulate E[x^2] per key so that var = E[x^2] - E[x]^2,
        # which reuses the same reduction pattern without broadcasting the mean;
        # min and max keep being refined on the newly sampled batches
        for _ in tqdm.tqdm(range(num_batch)):
            batch = rb.sample()
            for key, pattern in self.stats_patterns.items():
                batch[key] = batch[key].float()
                batch_mean_sq = einops.reduce(batch[key] ** 2, pattern, "mean")
                if key not in std:
                    # first batch initializes the running sum of E[x^2]
                    std[key] = batch_mean_sq
                else:
                    std[key] += batch_mean_sq
                max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

        for key in self.stats_patterns:
            # clamp guards against small negative values caused by floating point error
            std[key] = torch.sqrt(torch.clamp(std[key] / num_batch - mean[key] ** 2, min=0))

        stats = TensorDict({}, batch_size=[])
        for key in self.stats_patterns:
            stats[(*key, "mean")] = mean[key]
            stats[(*key, "std")] = std[key]
            stats[(*key, "max")] = max[key]
            stats[(*key, "min")] = min[key]

            if key[0] == "observation":
                # use the same stats for the next observations
                stats[("next", *key)] = stats[key]
        return stats
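

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the class above). The subclass
# name and dataset id "pusht" are hypothetical; a concrete dataset subclasses
# AbstractExperienceReplay and overrides `stats_patterns` / `image_keys` only
# if its keys differ from the defaults.
#
#   class PushtExperienceReplay(AbstractExperienceReplay):
#       pass
#
#   replay = PushtExperienceReplay("pusht", batch_size=32)
#   stats = replay.compute_or_load_stats()  # cached in <data_dir>/stats.pth
#   batch = replay.sample()                 # TensorDict of batch_size samples
# ---------------------------------------------------------------------------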