diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 3e0e2c32..e9613310 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -22,7 +22,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, @@ -32,7 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): ): self.dataset_id = dataset_id self.shuffle = shuffle - self.root = root if root is None else Path(root) + self.root = root storage = self._download_or_load_dataset() super().__init__( diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 2ea4b831..52a5676e 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -87,7 +87,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 876b6a50..3f4772c4 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,13 +1,16 @@ import logging import os +from pathlib import Path import torch from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler from lerobot.common.envs.transforms import NormalizeTransform, Prod -# used for unit tests -DATA_DIR = os.environ.get("DATA_DIR", None) +# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and +# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data` +# to load a subset of our datasets for faster continuous integration. +DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None def make_offline_buffer( diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index bac742d9..f4f6d9ac 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -90,7 +90,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index b4dd824f..7bcb03fb 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -43,7 +43,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None,