diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index b6b93d89..e05fb926 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,8 +1,9 @@ +import logging import os from pathlib import Path import torch -from torchrl.data.replay_buffers import PrioritizedSliceSampler +from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler from lerobot.common.datasets.pusht import PushtExperienceReplay from lerobot.common.datasets.simxarm import SimxarmExperienceReplay @@ -50,13 +51,22 @@ def make_offline_buffer(cfg, sampler=None): num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size. # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size. - sampler = PrioritizedSliceSampler( - max_capacity=100_000, - alpha=cfg.policy.per_alpha, - beta=cfg.policy.per_beta, - num_slices=num_traj_per_batch, - strict_length=False, - ) + + if cfg.offline_prioritized_sampler: + logging.info("use prioritized sampler for offline dataset") + sampler = PrioritizedSliceSampler( + max_capacity=100_000, + alpha=cfg.policy.per_alpha, + beta=cfg.policy.per_beta, + num_slices=num_traj_per_batch, + strict_length=False, + ) + else: + logging.info("use simple sampler for offline dataset") + sampler = SliceSampler( + num_slices=num_traj_per_batch, + strict_length=False, + ) if cfg.env.name == "simxarm": # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 0ea8f638..1adc8a9e 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -29,6 +29,8 @@ log_freq: 250 offline_steps: 1344000 online_steps: 0 +offline_prioritized_sampler: true + policy: name: diffusion