small fix in type + comments

This commit is contained in:
Cadene 2024-03-15 12:44:52 +00:00
parent 41521f7e96
commit 5805a7ffb1
5 changed files with 10 additions and 7 deletions

View File

@ -22,7 +22,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@ -32,7 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
):
self.dataset_id = dataset_id
self.shuffle = shuffle
self.root = root if root is None else Path(root)
self.root = root
storage = self._download_or_load_dataset()
super().__init__(

View File

@ -87,7 +87,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,

View File

@ -1,13 +1,16 @@
import logging
import os
from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.envs.transforms import NormalizeTransform, Prod
# used for unit tests
DATA_DIR = os.environ.get("DATA_DIR", None)
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_offline_buffer(

View File

@ -90,7 +90,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,

View File

@ -43,7 +43,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,