small fix in type + comments
This commit is contained in:
parent
41521f7e96
commit
5805a7ffb1
|
@ -22,7 +22,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
|||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
|
@ -32,7 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
|||
):
|
||||
self.dataset_id = dataset_id
|
||||
self.shuffle = shuffle
|
||||
self.root = root if root is None else Path(root)
|
||||
self.root = root
|
||||
storage = self._download_or_load_dataset()
|
||||
|
||||
super().__init__(
|
||||
|
|
|
@ -87,7 +87,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
|||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.envs.transforms import NormalizeTransform, Prod
|
||||
|
||||
# used for unit tests
|
||||
DATA_DIR = os.environ.get("DATA_DIR", None)
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
# to load a subset of our datasets for faster continuous integration.
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_offline_buffer(
|
||||
|
|
|
@ -90,7 +90,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
|
|||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
|
|
|
@ -43,7 +43,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
|
|||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
|
|
Loading…
Reference in New Issue