Add cfg.offline_prioritized_sampler
This commit is contained in:
parent
570f8d01df
commit
a027f4edfb
|
@ -1,8 +1,9 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||||
|
|
||||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||||
|
@ -50,13 +51,22 @@ def make_offline_buffer(cfg, sampler=None):
|
||||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||||
sampler = PrioritizedSliceSampler(
|
|
||||||
max_capacity=100_000,
|
if cfg.offline_prioritized_sampler:
|
||||||
alpha=cfg.policy.per_alpha,
|
logging.info("use prioritized sampler for offline dataset")
|
||||||
beta=cfg.policy.per_beta,
|
sampler = PrioritizedSliceSampler(
|
||||||
num_slices=num_traj_per_batch,
|
max_capacity=100_000,
|
||||||
strict_length=False,
|
alpha=cfg.policy.per_alpha,
|
||||||
)
|
beta=cfg.policy.per_beta,
|
||||||
|
num_slices=num_traj_per_batch,
|
||||||
|
strict_length=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("use simple sampler for offline dataset")
|
||||||
|
sampler = SliceSampler(
|
||||||
|
num_slices=num_traj_per_batch,
|
||||||
|
strict_length=False,
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.env.name == "simxarm":
|
if cfg.env.name == "simxarm":
|
||||||
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
|
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
|
||||||
|
|
|
@ -29,6 +29,8 @@ log_freq: 250
|
||||||
offline_steps: 1344000
|
offline_steps: 1344000
|
||||||
online_steps: 0
|
online_steps: 0
|
||||||
|
|
||||||
|
offline_prioritized_sampler: true
|
||||||
|
|
||||||
policy:
|
policy:
|
||||||
name: diffusion
|
name: diffusion
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue