Add cfg.offline_prioritized_sampler

Remi Cadene 2024-03-04 23:08:52 +00:00
parent 570f8d01df
commit a027f4edfb
2 changed files with 20 additions and 8 deletions

View File

@@ -1,8 +1,9 @@
+import logging
 import os
 from pathlib import Path
 import torch
-from torchrl.data.replay_buffers import PrioritizedSliceSampler
+from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
 from lerobot.common.datasets.pusht import PushtExperienceReplay
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
@@ -50,13 +51,22 @@ def make_offline_buffer(cfg, sampler=None):
     num_traj_per_batch = cfg.policy.batch_size  # // cfg.horizon
     # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
     # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
-    sampler = PrioritizedSliceSampler(
-        max_capacity=100_000,
-        alpha=cfg.policy.per_alpha,
-        beta=cfg.policy.per_beta,
-        num_slices=num_traj_per_batch,
-        strict_length=False,
-    )
+    if cfg.offline_prioritized_sampler:
+        logging.info("use prioritized sampler for offline dataset")
+        sampler = PrioritizedSliceSampler(
+            max_capacity=100_000,
+            alpha=cfg.policy.per_alpha,
+            beta=cfg.policy.per_beta,
+            num_slices=num_traj_per_batch,
+            strict_length=False,
+        )
+    else:
+        logging.info("use simple sampler for offline dataset")
+        sampler = SliceSampler(
+            num_slices=num_traj_per_batch,
+            strict_length=False,
+        )
+

     if cfg.env.name == "simxarm":
         # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
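
For reference, the change above only decides which torchrl sampler is handed to the replay buffer. Below is a minimal, self-contained sketch (not code from this repository) of the same PrioritizedSliceSampler / SliceSampler switch against a toy dataset; the literal values stand in for cfg.offline_prioritized_sampler, cfg.policy.per_alpha, cfg.policy.per_beta and the per-batch trajectory count.

# Sketch only: mirrors the sampler switch above with placeholder values.
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler

use_prioritized = True  # stands in for cfg.offline_prioritized_sampler
num_slices = 4          # stands in for num_traj_per_batch

if use_prioritized:
    sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
        alpha=0.6,  # placeholder for cfg.policy.per_alpha
        beta=0.4,   # placeholder for cfg.policy.per_beta
        num_slices=num_slices,
        strict_length=False,
    )
else:
    sampler = SliceSampler(num_slices=num_slices, strict_length=False)

# Toy offline dataset; the "episode" key is how SliceSampler finds trajectory boundaries.
data = TensorDict(
    {
        "observation": torch.randn(200, 3),
        "action": torch.randn(200, 2),
        "episode": torch.arange(200) // 25,  # 8 episodes of 25 steps
    },
    batch_size=[200],
)

buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100_000),
    sampler=sampler,
    batch_size=num_slices * 8,  # 4 trajectory slices of 8 contiguous steps
)
buffer.extend(data)

batch = buffer.sample()
print(batch.shape)  # torch.Size([32])

With strict_length=False the sampled batch can come back smaller than batch_size, which is exactly the situation the TODO above mentions padding for.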

View File

@@ -29,6 +29,8 @@ log_freq: 250
 offline_steps: 1344000
 online_steps: 0

+offline_prioritized_sampler: true
+
 policy:
   name: diffusion
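
The new key is read as cfg.offline_prioritized_sampler in make_offline_buffer. Assuming the YAML above is loaded through OmegaConf/Hydra (an assumption suggested by the cfg.policy.* access pattern, not stated in this commit), disabling the prioritized sampler is a one-line config change:

# Sketch under the assumption that cfg is an OmegaConf DictConfig;
# the commit itself only shows the YAML key.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    "offline_prioritized_sampler: false\n"
    "policy:\n"
    "  name: diffusion\n"
)
print(cfg.offline_prioritized_sampler)  # False -> make_offline_buffer falls back to SliceSampler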