From b10c9507d4f2df0984b77abcc2948f2cbbb31e9b Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 15 Mar 2024 00:36:55 +0000 Subject: [PATCH] Small fix --- lerobot/common/datasets/abstract.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 0e8fcc2b..61a0d25b 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -96,10 +96,10 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): def _download_or_load_dataset(self) -> torch.StorageBase: if self.root is None: - data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset") + self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset") else: - data_dir = Path(self.root) / self.dataset_id - return TensorStorage(TensorDict.load_memmap(data_dir)) + self.data_dir = Path(self.root) / self.dataset_id + return TensorStorage(TensorDict.load_memmap(self.data_dir)) def _compute_stats(self, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer(