From 5805a7ffb110281e63db0463c1fb9c9b57bca885 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 15 Mar 2024 12:44:52 +0000 Subject: [PATCH] small fix in type + comments --- lerobot/common/datasets/abstract.py | 4 ++-- lerobot/common/datasets/aloha.py | 2 +- lerobot/common/datasets/factory.py | 7 +++++-- lerobot/common/datasets/pusht.py | 2 +- lerobot/common/datasets/simxarm.py | 2 +- 5 files changed, 10 insertions(+), 7 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 3e0e2c32..e9613310 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -22,7 +22,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, @@ -32,7 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): ): self.dataset_id = dataset_id self.shuffle = shuffle - self.root = root if root is None else Path(root) + self.root = root storage = self._download_or_load_dataset() super().__init__( diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 2ea4b831..52a5676e 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -87,7 +87,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay): batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 876b6a50..3f4772c4 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,13 +1,16 @@ import logging import os +from pathlib import Path import torch from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler from lerobot.common.envs.transforms import NormalizeTransform, Prod -# used for unit tests -DATA_DIR = os.environ.get("DATA_DIR", None) +# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and +# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data` +# to load a subset of our datasets for faster continuous integration. +DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None def make_offline_buffer( diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index bac742d9..f4f6d9ac 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -90,7 +90,7 @@ class PushtExperienceReplay(AbstractExperienceReplay): batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index b4dd824f..7bcb03fb 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -43,7 +43,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None,