This commit is contained in:
Thomas Wolf 2024-05-28 11:08:55 +02:00
parent b6c216b590
commit 97cb7a2362
4 changed files with 108 additions and 15 deletions

View File

@ -43,8 +43,7 @@ def get_cameras(hdf5_data):
def check_format(raw_dir) -> bool:
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
compressed_images = None
hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
assert len(hdf5_paths) != 0
@ -62,18 +61,20 @@ def check_format(raw_dir) -> bool:
for camera in get_cameras(data):
assert num_frames == data[f"/observations/images/{camera}"].shape[0]
if compressed_images:
assert data[f"/observations/images/{camera}"].ndim == 2
assert data[f"/observations/images/{camera}"].ndim in [2, 4]
if data[f"/observations/images/{camera}"].ndim == 2:
assert compressed_images is None or compressed_images
compressed_images = True
else:
assert compressed_images is None or not compressed_images
compressed_images = False
assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
return compressed_images
def load_from_raw(raw_dir, out_dir, fps, video, debug):
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
def load_from_raw(raw_dir, out_dir, fps, video, debug, compressed_images):
hdf5_files = list(raw_dir.glob("*.hdf5"))
ep_dicts = []
episode_data_index = {"from": [], "to": []}
@ -199,12 +200,12 @@ def to_hf_dataset(data_dict, video) -> Dataset:
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
# sanity check
check_format(raw_dir)
compressed_images = check_format(raw_dir)
if fps is None:
fps = 50
data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug, compressed_images)
hf_dataset = to_hf_dataset(data_dir, video)
info = {

14
lerobot/configs/env/aloha_thom.yaml vendored Normal file
View File

@ -0,0 +1,14 @@
# @package _global_
fps: 50
env:
name: aloha
task: AlohaInsertion-v0
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]
episode_length: 500
fps: ${fps}
state_dim: 6
action_dim: 6

View File

@ -0,0 +1,77 @@
# @package _global_
seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human
training:
offline_steps: 20000
online_steps: 0
eval_freq: 100000
save_freq: 200
log_freq: 200
save_model: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.front: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@ -23,6 +23,7 @@ import hydra
import torch
from omegaconf import DictConfig
from torch.cuda.amp import GradScaler
from tqdm import tqdm
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@ -319,8 +320,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
policy.train()
is_offline = True
for step in range(cfg.training.offline_steps):
if step == 0:
for offline_step in tqdm(range(cfg.training.offline_steps)):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
batch = next(dl_iter)
@ -338,12 +339,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
if offline_step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, offline_step, cfg, offline_dataset, is_offline)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
evaluate_and_checkpoint_if_needed(offline_step + 1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)