From ccd5dc5a4253b3f284a96b63896e7ed328d60b02 Mon Sep 17 00:00:00 2001
From: Cadene
Date: Mon, 11 Mar 2024 12:33:15 +0000
Subject: [PATCH] fix training

---
Reviewer notes (this section is not part of the commit message):
- Compose/Transform are imported from the installed `torchrl` package rather
  than `rl.torchrl`, so the isinstance() checks below compare against the same
  classes the surrounding code imports from `torchrl`.
- set_transform wraps a non-Compose transform in Compose because torchrl calls
  `len(self._transform)` downstream (see the inline comment).
- make_env appends clones of the given transform(s) and raises
  NotImplementedError, now naming the offending type, for anything that is
  neither a Compose nor a Transform.

 lerobot/common/datasets/abstract.py |  8 +++++++-
 lerobot/common/envs/factory.py      | 10 +++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 1e67e9d8..0c0746e1 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.samplers import SliceSampler
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 
+from torchrl.envs.transforms.transforms import Compose
+
 
 class AbstractExperienceReplay(TensorDictReplayBuffer):
     def __init__(
@@ -78,7 +80,11 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         return self._transform
 
     def set_transform(self, transform):
-        self._transform = transform
+        if not isinstance(transform, Compose):
+            # required since torchrl calls `len(self._transform)` downstream
+            self._transform = Compose(transform)
+        else:
+            self._transform = transform
 
     def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 984b866a..184646cf 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,5 +1,7 @@
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 
+from torchrl.envs.transforms.transforms import Compose, Transform
+
 
 def make_env(cfg, transform=None):
     kwargs = {
@@ -33,7 +35,13 @@ def make_env(cfg, transform=None):
 
     if transform is not None:
         # useful to add normalization
-        env.append_transform(transform)
+        if isinstance(transform, Compose):
+            for tf in transform:
+                env.append_transform(tf.clone())
+        elif isinstance(transform, Transform):
+            env.append_transform(transform.clone())
+        else:
+            raise NotImplementedError(f"Unsupported transform type: {type(transform).__name__}")
 
     return env
 