Small fix
This commit is contained in:
parent
a311d38796
commit
b10c9507d4
|
@ -96,10 +96,10 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||||
|
|
||||||
def _download_or_load_dataset(self) -> torch.StorageBase:
|
def _download_or_load_dataset(self) -> torch.StorageBase:
|
||||||
if self.root is None:
|
if self.root is None:
|
||||||
data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
|
self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
|
||||||
else:
|
else:
|
||||||
data_dir = Path(self.root) / self.dataset_id
|
self.data_dir = Path(self.root) / self.dataset_id
|
||||||
return TensorStorage(TensorDict.load_memmap(data_dir))
|
return TensorStorage(TensorDict.load_memmap(self.data_dir))
|
||||||
|
|
||||||
def _compute_stats(self, num_batch=100, batch_size=32):
|
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||||
rb = TensorDictReplayBuffer(
|
rb = TensorDictReplayBuffer(
|
||||||
|
|
Loading…
Reference in New Issue