import logging
import os
from pathlib import Path
import torch
from torchrl . data . replay_buffers import PrioritizedSliceSampler , SliceSampler
from lerobot . common . transforms import NormalizeTransform , Prod
# DATA_DIR specifies the location where datasets are loaded from. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_offline_buffer(
    cfg,
    overwrite_sampler=None,
    # set normalize=False to remove all transformations and keep images unnormalized in [0,255]
    normalize=True,
    overwrite_batch_size=None,
    overwrite_prefetch=None,
    stats_path=None,
):
    """Instantiate the offline dataset / replay buffer described by the config.

    Args:
        cfg: hydra-style config; reads `cfg.policy.*`, `cfg.env.name`, `cfg.online_steps`,
            `cfg.offline_prioritized_sampler`, `cfg.device`, `cfg.prefetch` and
            optionally `cfg.dataset_id`.
        overwrite_sampler: when provided, used as-is instead of building a
            (Prioritized)SliceSampler; the sampler index is then assumed to be
            managed by the caller (it is not extended here).
        normalize: when False, no transform is attached and images stay in [0, 255].
        overwrite_batch_size: overrides the batch size derived from the config.
        overwrite_prefetch: overrides `cfg.prefetch`.
        stats_path: optional path to precomputed normalization stats (loaded with
            `torch.load`); when None, stats are computed or loaded by the dataset.

    Returns:
        The dataset object (SimxarmDataset, PushtDataset or AlohaDataset) with its
        sampler and, if requested, normalization transforms attached.

    Raises:
        ValueError: if `cfg.env.name` is not one of "simxarm", "pusht", "aloha".
    """
    if cfg.policy.balanced_sampling:
        # Online training: batches are assembled elsewhere (balanced with the
        # online buffer), so no fixed batch size / prefetch here.
        assert cfg.online_steps > 0
        batch_size = None
        pin_memory = False
        prefetch = None
    else:
        assert cfg.online_steps == 0
        num_slices = cfg.policy.batch_size
        batch_size = cfg.policy.horizon * num_slices
        pin_memory = cfg.device == "cuda"
        prefetch = cfg.prefetch

    if overwrite_batch_size is not None:
        batch_size = overwrite_batch_size

    if overwrite_prefetch is not None:
        prefetch = overwrite_prefetch

    if overwrite_sampler is None:
        # TODO(rcadene): move batch_size outside
        num_traj_per_batch = cfg.policy.batch_size  # // cfg.horizon
        # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
        # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.

        if cfg.offline_prioritized_sampler:
            logging.info("use prioritized sampler for offline dataset")
            sampler = PrioritizedSliceSampler(
                max_capacity=100_000,
                alpha=cfg.policy.per_alpha,
                beta=cfg.policy.per_beta,
                num_slices=num_traj_per_batch,
                strict_length=False,
            )
        else:
            logging.info("use simple sampler for offline dataset")
            sampler = SliceSampler(
                num_slices=num_traj_per_batch,
                strict_length=False,
            )
    else:
        sampler = overwrite_sampler

    # Lazy imports so only the selected environment's dependencies are required.
    if cfg.env.name == "simxarm":
        from lerobot.common.datasets.simxarm import SimxarmDataset

        clsfunc = SimxarmDataset
    elif cfg.env.name == "pusht":
        from lerobot.common.datasets.pusht import PushtDataset

        clsfunc = PushtDataset
    elif cfg.env.name == "aloha":
        from lerobot.common.datasets.aloha import AlohaDataset

        clsfunc = AlohaDataset
    else:
        raise ValueError(cfg.env.name)

    # TODO(rcadene): backward compatibility to load pretrained pusht policy
    dataset_id = cfg.get("dataset_id")
    if dataset_id is None and cfg.env.name == "pusht":
        dataset_id = "pusht"

    offline_buffer = clsfunc(
        dataset_id=dataset_id,
        sampler=sampler,
        batch_size=batch_size,
        root=DATA_DIR,
        pin_memory=pin_memory,
        # prefetch may be an OmegaConf node or None; only pass it through when
        # it is a plain int, as expected by the buffer.
        prefetch=prefetch if isinstance(prefetch, int) else None,
    )

    if cfg.policy.name == "tdmpc":
        img_keys = []
        # tdmpc also consumes next observations, so include their image keys.
        for key in offline_buffer.image_keys:
            img_keys.append(("next", *key))
        img_keys += offline_buffer.image_keys
    else:
        img_keys = offline_buffer.image_keys

    if normalize:
        transforms = [Prod(in_keys=img_keys, prod=1 / 255)]

        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
        # min_max_from_spec
        stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)

        # we only normalize the state and action, since the images are usually normalized inside the model for
        # now (except for tdmpc: see the following)
        in_keys = [("observation", "state"), ("action")]

        if cfg.policy.name == "tdmpc":
            # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act
            # policies normalize the image inside the model for now
            in_keys += img_keys
            # TODO(rcadene): since we use next observations in tdmpc, we also add them to the normalization.
            # We are wasting a bit of compute on this for now.
            # NOTE(review): img_keys already contains ("next", *key) entries at this point, so this also
            # produces doubly-prefixed ("next", "next", ...) keys — confirm this is intended.
            in_keys += [("next", *key) for key in img_keys]
            in_keys.append(("next", "observation", "state"))

        if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
            # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
            stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
            stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
            stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
            stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

        # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just
        # with mean_std
        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
        transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))

        # Attach transforms only when normalizing; `transforms` is undefined otherwise.
        offline_buffer.set_transform(transforms)

    if not overwrite_sampler:
        index = torch.arange(0, offline_buffer.num_samples, 1)
        sampler.extend(index)

    return offline_buffer