commit 97cb7a2362
parent b6c216b590

    save
```diff
@@ -43,8 +43,7 @@ def get_cameras(hdf5_data):
 
 
 def check_format(raw_dir) -> bool:
-    # only frames from simulation are uncompressed
-    compressed_images = "sim" not in raw_dir.name
+    compressed_images = None
 
     hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
     assert len(hdf5_paths) != 0
```
```diff
@@ -62,18 +61,20 @@ def check_format(raw_dir) -> bool:
             for camera in get_cameras(data):
                 assert num_frames == data[f"/observations/images/{camera}"].shape[0]
 
-                if compressed_images:
-                    assert data[f"/observations/images/{camera}"].ndim == 2
+                assert data[f"/observations/images/{camera}"].ndim in [2, 4]
+                if data[f"/observations/images/{camera}"].ndim == 2:
+                    assert compressed_images is None or compressed_images
+                    compressed_images = True
                 else:
+                    assert compressed_images is None or not compressed_images
+                    compressed_images = False
                     assert data[f"/observations/images/{camera}"].ndim == 4
                     b, h, w, c = data[f"/observations/images/{camera}"].shape
                     assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
+    return compressed_images
 
 
-def load_from_raw(raw_dir, out_dir, fps, video, debug):
-    # only frames from simulation are uncompressed
-    compressed_images = "sim" not in raw_dir.name
 
+def load_from_raw(raw_dir, out_dir, fps, video, debug, compressed_images):
 
     hdf5_files = list(raw_dir.glob("*.hdf5"))
     ep_dicts = []
     episode_data_index = {"from": [], "to": []}
```
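The two hunks above replace the old directory-name heuristic (`compressed_images = "sim" not in raw_dir.name`) with detection from the HDF5 layout itself: compressed episodes store each frame as a padded 1-D byte buffer, so the per-camera dataset is 2-D with shape `(num_frames, padded_length)`, while uncompressed frames form a 4-D `(num_frames, h, w, c)` array. The `compressed_images is None or ...` asserts make `check_format` fail if cameras or episodes disagree, and the verified flag is returned rather than recomputed by the loader. Below is a minimal sketch of how a consumer might branch on that flag, assuming OpenCV decodes the padded buffers; the helper `decode_frame` is illustrative, not part of this commit:

```python
import cv2
import h5py
import numpy as np


def decode_frame(data: h5py.File, camera: str, idx: int, compressed_images: bool) -> np.ndarray:
    """Illustrative: return frame `idx` of `camera` as an (h, w, c) uint8 array."""
    images = data[f"/observations/images/{camera}"]
    if compressed_images:
        # Each row is a zero-padded JPEG/PNG byte buffer; imdecode stops at the
        # end-of-image marker, so the trailing padding is ignored.
        buf = np.asarray(images[idx], dtype=np.uint8)
        return cv2.imdecode(buf, cv2.IMREAD_COLOR)
    # Uncompressed: the dataset is already stored as (num_frames, h, w, c).
    return images[idx]
```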
```diff
@@ -199,12 +200,12 @@ def to_hf_dataset(data_dict, video) -> Dataset:
 
 
 def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
     # sanity check
-    check_format(raw_dir)
+    compressed_images = check_format(raw_dir)
 
     if fps is None:
         fps = 50
 
-    data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
+    data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug, compressed_images)
     hf_dataset = to_hf_dataset(data_dir, video)
 
     info = {
```
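The hunk above threads the flag from `check_format` into `load_from_raw`, so the sanity check and the loader can no longer disagree about compression. A hypothetical call under the new signatures (paths are illustrative; both functions come from the module patched above):

```python
from pathlib import Path

raw_dir = Path("data/raw/aloha_demo")      # hypothetical input directory
out_dir = Path("data/lerobot/aloha_demo")  # hypothetical output directory

compressed_images = check_format(raw_dir)  # validates layout, reports compression
data_dir, episode_data_index = load_from_raw(
    raw_dir, out_dir, fps=50, video=True, debug=False, compressed_images=compressed_images
)
```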
```diff
@@ -0,0 +1,14 @@
+# @package _global_
+
+fps: 50
+
+env:
+  name: aloha
+  task: AlohaInsertion-v0
+  from_pixels: True
+  pixels_only: False
+  image_size: [3, 480, 640]
+  episode_length: 500
+  fps: ${fps}
+  state_dim: 6
+  action_dim: 6
```
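This new env config relies on two Hydra/OmegaConf features: the `# @package _global_` directive merges the file's keys at the root of the composed config, and `fps: ${fps}` is an interpolation that resolves against the top-level `fps: 50` each time `env.fps` is read. A standalone sketch of the resolution behavior:

```python
from omegaconf import OmegaConf

# Mirror the two keys involved; "${fps}" is resolved lazily, on access.
cfg = OmegaConf.create({"fps": 50, "env": {"name": "aloha", "fps": "${fps}"}})
assert cfg.env.fps == 50
```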
```diff
@@ -0,0 +1,77 @@
+# @package _global_
+
+seed: 1000
+dataset_repo_id: lerobot/aloha_sim_insertion_human
+
+training:
+  offline_steps: 20000
+  online_steps: 0
+  eval_freq: 100000
+  save_freq: 200
+  log_freq: 200
+  save_model: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+# See `configuration_act.py` for more details.
+policy:
+  name: act
+
+  # Input / output structure.
+  n_obs_steps: 1
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images: [3, 480, 640]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.front: mean_std
+    observation.state: mean_std
+  output_normalization_modes:
+    action: mean_std
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  dim_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  temporal_ensemble_momentum: null
+
+  # Training and loss computation.
+  dropout: 0.1
+  kl_weight: 10.0
```
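The `delta_timestamps.action` entry above is a Python list comprehension kept as a string; after Hydra substitutes `${fps}` and `${policy.chunk_size}`, the dataset code evaluates it to get the relative timestamps of the action targets. With the values above that yields 100 offsets spaced 20 ms apart, i.e. every sample is paired with a 2-second chunk of future actions. Expanded by hand:

```python
fps = 50          # top-level fps from the env config
chunk_size = 100  # policy.chunk_size

delta_timestamps = [i / fps for i in range(chunk_size)]
assert delta_timestamps[:3] == [0.0, 0.02, 0.04]
assert delta_timestamps[-1] == 1.98  # 2-second action horizon at 50 Hz
```

With `n_action_steps` equal to `chunk_size` and `temporal_ensemble_momentum: null`, the whole predicted chunk is executed before the policy is queried again.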
```diff
@@ -23,6 +23,7 @@ import hydra
 import torch
 from omegaconf import DictConfig
 from torch.cuda.amp import GradScaler
+from tqdm import tqdm
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
```
```diff
@@ -319,8 +320,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
 
     policy.train()
     is_offline = True
-    for step in range(cfg.training.offline_steps):
-        if step == 0:
+    for offline_step in tqdm(range(cfg.training.offline_steps)):
+        if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
         batch = next(dl_iter)
 
```
```diff
@@ -338,12 +339,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         )
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
-        if step % cfg.training.log_freq == 0:
-            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
+        if offline_step % cfg.training.log_freq == 0:
+            log_train_info(logger, train_info, offline_step, cfg, offline_dataset, is_offline)
 
         # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
         # so we pass in step + 1.
-        evaluate_and_checkpoint_if_needed(step + 1)
+        evaluate_and_checkpoint_if_needed(offline_step + 1)
 
     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
```
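The last three hunks rename the offline loop counter from `step` to `offline_step` and wrap the range in `tqdm` (hence the new import) to get a progress bar. The rename keeps the offline counter distinct from the online phase set up right after the loop (`online_dataset = deepcopy(offline_dataset)`); note the comments still refer to `step`. A sketch of the two-phase counter pattern this enables; every name except `offline_step` is illustrative:

```python
from tqdm import tqdm

offline_steps, online_steps = 20000, 0  # cf. training.offline_steps / online_steps above

for offline_step in tqdm(range(offline_steps)):
    ...  # one optimizer update on a batch from the fixed offline dataset

for online_step in tqdm(range(online_steps)):
    global_step = offline_steps + online_step  # hypothetical combined counter for logging
    ...  # collect a rollout into the online dataset, then update
```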