Small fix

This commit is contained in:
Cadene 2024-03-15 00:36:55 +00:00
parent a311d38796
commit b10c9507d4
1 changed files with 3 additions and 3 deletions

View File

@ -96,10 +96,10 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def _download_or_load_dataset(self) -> torch.StorageBase:
if self.root is None:
data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
else:
data_dir = Path(self.root) / self.dataset_id
return TensorStorage(TensorDict.load_memmap(data_dir))
self.data_dir = Path(self.root) / self.dataset_id
return TensorStorage(TensorDict.load_memmap(self.data_dir))
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(