Merge branch 'main' into tutorial_act_pusht

commit 69a93329d0
Author: Alexander Soare
Date:   2024-05-20 20:44:03 +01:00 (committed by GitHub)
5 changed files with 112 additions and 193 deletions


@@ -20,6 +20,8 @@ build-gpu:
test-end-to-end:
	${MAKE} test-act-ete-train
	${MAKE} test-act-ete-eval
+	${MAKE} test-act-ete-train-amp
+	${MAKE} test-act-ete-eval-amp
	${MAKE} test-diffusion-ete-train
	${MAKE} test-diffusion-ete-eval
	${MAKE} test-tdmpc-ete-train
@@ -30,6 +32,7 @@ test-end-to-end:
test-act-ete-train:
	python lerobot/scripts/train.py \
		policy=act \
+		policy.dim_model=64 \
		env=aloha \
		wandb.enable=False \
		training.offline_steps=2 \
@@ -52,9 +55,40 @@ test-act-ete-eval:
		env.episode_length=8 \
		device=cpu \

+test-act-ete-train-amp:
+	python lerobot/scripts/train.py \
+		policy=act \
+		policy.dim_model=64 \
+		env=aloha \
+		wandb.enable=False \
+		training.offline_steps=2 \
+		training.online_steps=0 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		device=cpu \
+		training.save_model=true \
+		training.save_freq=2 \
+		policy.n_action_steps=20 \
+		policy.chunk_size=20 \
+		training.batch_size=2 \
+		hydra.run.dir=tests/outputs/act/ \
+		use_amp=true
+
+test-act-ete-eval-amp:
+	python lerobot/scripts/eval.py \
+		-p tests/outputs/act/checkpoints/000002 \
+		eval.n_episodes=1 \
+		eval.batch_size=1 \
+		env.episode_length=8 \
+		device=cpu \
+		use_amp=true
+
test-diffusion-ete-train:
	python lerobot/scripts/train.py \
		policy=diffusion \
+		policy.down_dims=\[64,128,256\] \
+		policy.diffusion_step_embed_dim=32 \
+		policy.num_inference_steps=10 \
		env=pusht \
		wandb.enable=False \
		training.offline_steps=2 \
@@ -75,6 +109,7 @@ test-diffusion-ete-eval:
		env.episode_length=8 \
		device=cpu \

+# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
test-tdmpc-ete-train:
	python lerobot/scripts/train.py \
		policy=tdmpc \
@@ -83,7 +118,7 @@ test-tdmpc-ete-train:
		dataset_repo_id=lerobot/xarm_lift_medium \
		wandb.enable=False \
		training.offline_steps=2 \
-		training.online_steps=2 \
+		training.online_steps=0 \
		eval.n_episodes=1 \
		eval.batch_size=1 \
		env.episode_length=2 \
@@ -101,7 +136,6 @@ test-tdmpc-ete-eval:
		env.episode_length=8 \
		device=cpu \

test-default-ete-eval:
	python lerobot/scripts/eval.py \
		--config lerobot/configs/default.yaml \
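Note on the new *-amp targets above: they run the same end-to-end tests with `device=cpu use_amp=true`, so they exercise `torch.autocast` on CPU rather than CUDA. A minimal standalone sketch (not from the repo; assumes a PyTorch build with CPU autocast, which the scripts in this commit already require) of what that toggles — matmuls inside the context run in bfloat16:

    import torch

    # Under CPU autocast, matmul runs in bfloat16, the reduced precision
    # exercised by the `device=cpu use_amp=true` Makefile targets.
    x = torch.randn(8, 8)
    w = torch.randn(8, 8)
    with torch.autocast(device_type="cpu"):
        y = x @ w
    print(y.dtype)  # torch.bfloat16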


@@ -10,6 +10,9 @@ hydra:
    name: default

device: cuda # cpu
+# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+# automatic gradient scaling is used.
+use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
@@ -17,6 +20,7 @@ dataset_repo_id: lerobot/pusht

training:
  offline_steps: ???
+  # NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
  online_steps: ???
  online_steps_between_rollouts: ???
  online_sampling_ratio: 0.5
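For readers unfamiliar with the config syntax above: `???` is OmegaConf's mandatory-missing marker, so those keys must be supplied as Hydra overrides before they are read, while `use_amp: false` is an ordinary default. A small illustrative snippet (not from the repo):

    from omegaconf import OmegaConf

    # `???` marks a mandatory value; reading it before it is overridden raises
    # MissingMandatoryValue, whereas a defaulted key like use_amp reads normally.
    cfg = OmegaConf.create("seed: ???\nuse_amp: false")
    print(OmegaConf.is_missing(cfg, "seed"))  # True
    print(cfg.use_amp)                        # False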


@@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium

training:
  offline_steps: 25000
-  online_steps: 25000
+  # TODO(alexander-soare): uncomment when online training gets reinstated
+  online_steps: 0 # 25000 not implemented yet
  eval_freq: 5000
  online_steps_between_rollouts: 1
  online_sampling_ratio: 0.5


@@ -46,6 +46,7 @@ import json
import logging
import threading
import time
+from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
@@ -520,7 +521,7 @@ def eval(
        raise NotImplementedError()

    # Check device is available
-    get_safe_torch_device(hydra_cfg.device, log=True)
+    device = get_safe_torch_device(hydra_cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
@@ -539,16 +540,17 @@ def eval(
    policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
    policy.eval()

-    info = eval_policy(
-        env,
-        policy,
-        hydra_cfg.eval.n_episodes,
-        max_episodes_rendered=10,
-        video_dir=Path(out_dir) / "eval",
-        start_seed=hydra_cfg.seed,
-        enable_progbar=True,
-        enable_inner_progbar=True,
-    )
+    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
+        info = eval_policy(
+            env,
+            policy,
+            hydra_cfg.eval.n_episodes,
+            max_episodes_rendered=10,
+            video_dir=Path(out_dir) / "eval",
+            start_seed=hydra_cfg.seed,
+            enable_progbar=True,
+            enable_inner_progbar=True,
+        )
    print(info["aggregated"])

    # Save info
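The evaluation path now combines two context managers in a single with statement; the second item is a conditional expression that falls back to nullcontext() when AMP is off, so one line covers both modes. A self-contained sketch of the pattern (the flag and model below are illustrative, not repo code):

    from contextlib import nullcontext
    import torch

    use_amp = False  # stands in for hydra_cfg.use_amp
    device = torch.device("cpu")
    model = torch.nn.Linear(4, 2)

    # no_grad() is always active; autocast is only entered when the flag is set.
    with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
        out = model(torch.randn(1, 4))
    print(out.requires_grad)  # False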


@@ -15,15 +15,14 @@
# limitations under the License.
import logging
import time
+from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path

-import datasets
import hydra
import torch
-from datasets import concatenate_datasets
-from datasets.utils import disable_progress_bars, enable_progress_bars
from omegaconf import DictConfig
+from torch.cuda.amp import GradScaler

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -31,6 +30,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
+from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
@@ -69,7 +69,6 @@ def make_optimizer_and_scheduler(cfg, policy):
            cfg.training.adam_eps,
            cfg.training.adam_weight_decay,
        )
-        assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
        from diffusers.optimization import get_scheduler

        lr_scheduler = get_scheduler(
@@ -87,21 +86,40 @@ def make_optimizer_and_scheduler(cfg, policy):
    return optimizer, lr_scheduler


-def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+def update_policy(
+    policy,
+    batch,
+    optimizer,
+    grad_clip_norm,
+    grad_scaler: GradScaler,
+    lr_scheduler=None,
+    use_amp: bool = False,
+):
    """Returns a dictionary of items for logging."""
-    start_time = time.time()
+    start_time = time.perf_counter()
+    device = get_device_from_parameters(policy)
    policy.train()
-    output_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = output_dict["loss"]
-    loss.backward()
+    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+        output_dict = policy.forward(batch)
+        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+        loss = output_dict["loss"]
+    grad_scaler.scale(loss).backward()
+
+    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+    grad_scaler.unscale_(optimizer)
+
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )
-    optimizer.step()
+
+    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+    grad_scaler.step(optimizer)
+    # Updates the scale for next iteration.
+    grad_scaler.update()
+
    optimizer.zero_grad()

    if lr_scheduler is not None:
@@ -115,7 +133,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
-        "update_s": time.time() - start_time,
+        "update_s": time.perf_counter() - start_time,
        **{k: v for k, v in output_dict.items() if k != "loss"},
    }
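The ordering in the new update_policy matters: gradients must be unscaled before clip_grad_norm_, otherwise the clipping threshold would be compared against loss-scaled gradients. A toy, CPU-only illustration with a manual scale factor standing in for the GradScaler:

    import torch

    param = torch.nn.Parameter(torch.ones(4))
    loss = (param * 3.0).sum()
    (loss * 1024.0).backward()               # stand-in for grad_scaler.scale(loss).backward()
    print(param.grad.norm())                 # 6144.0: norm of the scaled gradients
    param.grad /= 1024.0                     # stand-in for grad_scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)
    print(param.grad.norm())                 # ~1.0: the threshold now applies to the true gradients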
@@ -211,103 +229,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
    logger.log_dict(info, step, mode="eval")


-def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
-    """
-    Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
-
-    Parameters:
-    - n_off (int): Number of offline samples, each with a sampling weight of 1.
-    - n_on (int): Number of online samples.
-    - pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
-
-    The total weight of offline samples is n_off * 1.0.
-    The total weight of online samples is n_on * w.
-    The total combined weight of all samples is n_off + n_on * w.
-    The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
-    We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
-    The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
-    """
-    assert 0.0 <= pc_on <= 1.0
-    return -(n_off * pc_on) / (n_on * (pc_on - 1))
-
-
-def add_episodes_inplace(
-    online_dataset: torch.utils.data.Dataset,
-    concat_dataset: torch.utils.data.ConcatDataset,
-    sampler: torch.utils.data.WeightedRandomSampler,
-    hf_dataset: datasets.Dataset,
-    episode_data_index: dict[str, torch.Tensor],
-    pc_online_samples: float,
-):
-    """
-    Modifies the online_dataset, concat_dataset, and sampler in place by integrating
-    new episodes from hf_dataset into the online_dataset, updating the concatenated
-    dataset's structure and adjusting the sampling strategy based on the specified
-    percentage of online samples.
-
-    Parameters:
-    - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
-    - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
-      offline and online datasets, used for sampling purposes.
-    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
-      reflect changes in the dataset sizes and specified sampling weights.
-    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
-    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
-      They indicate the start index and end index of each episode in the dataset.
-    - pc_online_samples (float): The target percentage of samples that should come from
-      the online dataset during sampling operations.
-
-    Raises:
-    - AssertionError: If the first episode_id or index in hf_dataset is not 0
-    """
-    first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
-    last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
-    first_index = hf_dataset.select_columns("index")[0]["index"].item()
-    last_index = hf_dataset.select_columns("index")[-1]["index"].item()
-    # sanity check
-    assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
-    assert first_index == 0, f"{first_index=} is not 0"
-    assert first_index == episode_data_index["from"][first_episode_idx].item()
-    assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
-
-    if len(online_dataset) == 0:
-        # initialize online dataset
-        online_dataset.hf_dataset = hf_dataset
-        online_dataset.episode_data_index = episode_data_index
-    else:
-        # get the starting indices of the new episodes and frames to be added
-        start_episode_idx = last_episode_idx + 1
-        start_index = last_index + 1
-
-        def shift_indices(episode_index, index):
-            # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
-            example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
-            return example
-
-        disable_progress_bars()  # map has a tqdm progress bar
-        hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
-        enable_progress_bars()
-
-        episode_data_index["from"] += start_index
-        episode_data_index["to"] += start_index
-
-        # extend online dataset
-        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
-
-    # update the concatenated dataset length used during sampling
-    concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
-
-    # update the sampling weights for each frame so that online frames get sampled a certain percentage of times
-    len_online = len(online_dataset)
-    len_offline = len(concat_dataset) - len_online
-    weight_offline = 1.0
-    weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
-    sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
-
-    # update the total number of samples used during sampling
-    sampler.num_samples = len(concat_dataset)
-
-
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    if out_dir is None:
        raise NotImplementedError()
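As a quick sanity check of the weight formula documented in the removed calculate_online_sample_weight helper (purely illustrative numbers):

    n_off, n_on, pc_on = 100, 25, 0.5
    w = -(n_off * pc_on) / (n_on * (pc_on - 1))
    print(w)                              # 4.0
    print(n_on * w / (n_off + n_on * w))  # 0.5, the requested online fraction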
@@ -316,11 +237,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    init_logging()

-    if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
-        logging.warning("eval.batch_size > 1 not supported for online training steps")
+    if cfg.training.online_steps > 0:
+        raise NotImplementedError("Online training is not implemented yet.")

    # Check device is available
-    get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
@@ -338,6 +259,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    # Create optimizer and scheduler
    # Temporary hack to move optimizer out of policy
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+    grad_scaler = GradScaler(enabled=cfg.use_amp)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())
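Constructing the scaler with enabled=cfg.use_amp lets the training step keep a single code path: when AMP is off, scale/step/update are effectively pass-throughs around a normal optimizer step. A small CPU-only sketch (illustrative, not repo code):

    import torch
    from torch.cuda.amp import GradScaler

    scaler = GradScaler(enabled=False)        # what use_amp=false amounts to
    param = torch.nn.Parameter(torch.ones(2))
    opt = torch.optim.SGD([param], lr=0.1)
    loss = (param ** 2).sum()
    scaler.scale(loss).backward()             # no scaling applied when disabled
    scaler.step(opt)                          # falls back to opt.step()
    scaler.update()                           # no-op
    print(param.detach())                     # tensor([0.8000, 0.8000])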
@@ -358,14 +280,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    def evaluate_and_checkpoint_if_needed(step):
        if step % cfg.training.eval_freq == 0:
            logging.info(f"Eval policy at step {step}")
-            eval_info = eval_policy(
-                eval_env,
-                policy,
-                cfg.eval.n_episodes,
-                video_dir=Path(out_dir) / "eval",
-                max_episodes_rendered=4,
-                start_seed=cfg.seed,
-            )
+            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+                eval_info = eval_policy(
+                    eval_env,
+                    policy,
+                    cfg.eval.n_episodes,
+                    video_dir=Path(out_dir) / "eval",
+                    max_episodes_rendered=4,
+                    start_seed=cfg.seed,
+                )
            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
            if cfg.wandb.enable:
                logger.log_video(eval_info["video_paths"][0], step, mode="eval")
@@ -389,23 +312,30 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
        num_workers=4,
        batch_size=cfg.training.batch_size,
        shuffle=True,
-        pin_memory=cfg.device != "cpu",
+        pin_memory=device.type != "cpu",
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    policy.train()
-    step = 0  # number of policy update (forward + backward + optim)
    is_offline = True
-    for offline_step in range(cfg.training.offline_steps):
-        if offline_step == 0:
+    for step in range(cfg.training.offline_steps):
+        if step == 0:
            logging.info("Start offline training on a fixed dataset")
        batch = next(dl_iter)

        for key in batch:
-            batch[key] = batch[key].to(cfg.device, non_blocking=True)
+            batch[key] = batch[key].to(device, non_blocking=True)

-        train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
+        train_info = update_policy(
+            policy,
+            batch,
+            optimizer,
+            cfg.training.grad_clip_norm,
+            grad_scaler=grad_scaler,
+            lr_scheduler=lr_scheduler,
+            use_amp=cfg.use_amp,
+        )

        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
        if step % cfg.training.log_freq == 0:
@@ -415,11 +345,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
        # so we pass in step + 1.
        evaluate_and_checkpoint_if_needed(step + 1)

-        step += 1
-
-    # create an env dedicated to online episodes collection from policy rollout
-    online_training_env = make_env(cfg, n_envs=1)
-
    # create an empty online dataset similar to offline dataset
    online_dataset = deepcopy(offline_dataset)
    online_dataset.hf_dataset = {}
@@ -436,58 +361,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
        num_workers=4,
        batch_size=cfg.training.batch_size,
        sampler=sampler,
-        pin_memory=cfg.device != "cpu",
+        pin_memory=device.type != "cpu",
        drop_last=False,
    )
-    dl_iter = cycle(dataloader)
-
-    online_step = 0
-    is_offline = False
-    for env_step in range(cfg.training.online_steps):
-        if env_step == 0:
-            logging.info("Start online training by interacting with environment")
-
-        policy.eval()
-        with torch.no_grad():
-            eval_info = eval_policy(
-                online_training_env,
-                policy,
-                n_episodes=1,
-                return_episode_data=True,
-                start_seed=cfg.training.online_env_seed,
-                enable_progbar=True,
-            )
-
-        add_episodes_inplace(
-            online_dataset,
-            concat_dataset,
-            sampler,
-            hf_dataset=eval_info["episodes"]["hf_dataset"],
-            episode_data_index=eval_info["episodes"]["episode_data_index"],
-            pc_online_samples=cfg.training.online_sampling_ratio,
-        )
-
-        policy.train()
-        for _ in range(cfg.training.online_steps_between_rollouts):
-            batch = next(dl_iter)
-
-            for key in batch:
-                batch[key] = batch[key].to(cfg.device, non_blocking=True)
-
-            train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
-
-            if step % cfg.training.log_freq == 0:
-                log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
-
-            # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
-            # so we pass in step + 1.
-            evaluate_and_checkpoint_if_needed(step + 1)
-
-            step += 1
-            online_step += 1
-
    eval_env.close()
-    online_training_env.close()
    logging.info("End of training")