import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler

from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay

# TODO(rcadene): implement
# dataset_d4rl = D4RLExperienceReplay(
#     dataset_id="maze2d-umaze-v1",
#     split_trajs=False,
#     batch_size=1,
#     sampler=SamplerWithoutReplacement(drop_last=False),
#     prefetch=4,
#     direct_download=True,
# )

# dataset_openx = OpenXExperienceReplay(
#     "cmu_stretch",
#     batch_size=1,
#     num_slices=1,
#     # download="force",
#     streaming=False,
#     root="data",
# )


def make_offline_buffer(cfg, sampler=None):
    """Build the offline replay buffer for the environment named in `cfg.env.name`.

    If no `sampler` is provided, a PrioritizedSliceSampler is created and
    seeded with every transition index of the dataset.
    """
    if cfg.policy.balanced_sampling:
        # Balanced sampling mixes offline and online data, so batching and
        # prefetching are not configured on the offline buffer itself.
        assert cfg.online_steps > 0
        batch_size = None
        pin_memory = False
        prefetch = None
    else:
        assert cfg.online_steps == 0
        num_slices = cfg.policy.batch_size
        batch_size = cfg.policy.horizon * num_slices
        # Pin host memory only when training on GPU.
        pin_memory = cfg.device == "cuda"
        prefetch = cfg.prefetch

    overwrite_sampler = sampler is not None

    if not overwrite_sampler:
        # TODO(rcadene): move batch_size outside
        num_traj_per_batch = cfg.policy.batch_size  # // cfg.horizon
        # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
        # We would need to add a transform to pad the tensordict to ensure
        # batch_size == cfg.batch_size.
        sampler = PrioritizedSliceSampler(
            max_capacity=100_000,
            alpha=cfg.policy.per_alpha,
            beta=cfg.policy.per_beta,
            num_slices=num_traj_per_batch,
            strict_length=False,
        )

    if cfg.env.name == "simxarm":
        # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not
        # have to `sampler.extend(index)` here
        offline_buffer = SimxarmExperienceReplay(
            f"xarm_{cfg.env.task}_medium",
            # download="force",
            download=True,
            streaming=False,
            root="data",
            sampler=sampler,
            batch_size=batch_size,
            pin_memory=pin_memory,
            prefetch=prefetch,
        )
    elif cfg.env.name == "pusht":
        offline_buffer = PushtExperienceReplay(
            "pusht",
            # download="force",
            # TODO(aliberts): automate download
            download=False,
            streaming=False,
            root="data",
            sampler=sampler,
            batch_size=batch_size,
            pin_memory=pin_memory,
            prefetch=prefetch,
        )
    else:
        raise ValueError(cfg.env.name)

    if not overwrite_sampler:
        # Register every transition index with the freshly created sampler so
        # that prioritized sampling covers the whole dataset.
        num_steps = len(offline_buffer)
        index = torch.arange(0, num_steps, 1)
        sampler.extend(index)

    return offline_buffer
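

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original module: build a config
    # exposing the fields make_offline_buffer reads. The values below are
    # illustrative assumptions, not the project's real defaults, and the
    # pusht branch expects the dataset to already exist under data/
    # (download=False above).
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "device": "cuda",
            "online_steps": 0,
            "prefetch": 4,
            "env": {"name": "pusht", "task": "pusht"},
            "policy": {
                "balanced_sampling": False,
                "batch_size": 256,
                "horizon": 5,
                "per_alpha": 0.6,
                "per_beta": 0.4,
            },
        }
    )
    offline_buffer = make_offline_buffer(cfg)
    # Assuming the dataset classes follow torchrl's TensorDictReplayBuffer
    # interface, sampling returns a tensordict batch of horizon-length slices.
    batch = offline_buffer.sample()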