From dae901f5564512151ee6fde3801490e1ec889b6a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com>
Date: Thu, 25 Apr 2024 10:45:10 +0200
Subject: [PATCH] move update outside policies

---
 lerobot/common/policies/act/modeling_act.py  | 28 -------
 .../policies/diffusion/modeling_diffusion.py | 36 ---------
 lerobot/scripts/train.py                     | 78 +++++++++++++++++--
 3 files changed, 73 insertions(+), 69 deletions(-)

diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 869ecd7b..c727988b 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -5,7 +5,6 @@ The majority of changes here involve removing unused code, unifying naming, and
 """
 
 import math
-import time
 from collections import deque
 from itertools import chain
 from typing import Callable
@@ -206,33 +205,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         return loss_dict
 
-    def update(self, batch, **_) -> dict:
-        """Run the model in train mode, compute the loss, and do an optimization step."""
-        start_time = time.time()
-        self.train()
-
-        batch = self.normalize_inputs(batch)
-
-        loss_dict = self.forward(batch)
-        # TODO(rcadene): self.unnormalize_outputs(out_dict)
-        loss = loss_dict["loss"]
-        loss.backward()
-
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
-        )
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.cfg.lr,
-            "update_s": time.time() - start_time,
-        }
-
-        return info
 
     def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Stacks all the images in a batch and puts them in a new key: "observation.images".
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 088b6cb6..4f488c01 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -11,7 +11,6 @@ TODO(alexander-soare):
 import copy
 import logging
 import math
-import time
 from collections import deque
 from typing import Callable
 
@@ -155,41 +154,6 @@ class DiffusionPolicy(nn.Module):
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
 
-    def update(self, batch: dict[str, Tensor], **_) -> dict:
-        """Run the model in train mode, compute the loss, and do an optimization step."""
-        start_time = time.time()
-
-        self.diffusion.train()
-
-        batch = self.normalize_inputs(batch)
-
-        loss = self.forward(batch)["loss"]
-        loss.backward()
-
-        # TODO(rcadene): self.unnormalize_outputs(out_dict)
-
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.diffusion.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
-        )
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
-
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.lr_scheduler.get_last_lr()[0],
-            "update_s": time.time() - start_time,
-        }
-
-        return info
-
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index c849cce8..4ff29573 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -1,4 +1,5 @@
 import logging
+import time
 from copy import deepcopy
 from pathlib import Path
 
@@ -7,6 +8,7 @@
 import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from torch import Tensor
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -14,14 +16,75 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
+from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils.utils import (
-    format_big_number,
-    get_safe_torch_device,
-    init_logging,
-    set_global_seed,
+    format_big_number,
+    get_safe_torch_device,
+    init_logging,
+    set_global_seed,
 )
 from lerobot.scripts.eval import eval_policy
 
 
+def update_diffusion(policy, batch: dict[str, Tensor], **_) -> dict:
+    """Run the model in train mode, compute the loss, and do an optimization step."""
+    start_time = time.time()
+    policy.diffusion.train()
+    batch = policy.normalize_inputs(batch)
+    loss = policy.forward(batch)["loss"]
+    loss.backward()
+
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.diffusion.parameters(),
+        policy.cfg.grad_clip_norm,
+        error_if_nonfinite=False,
+    )
+
+    policy.optimizer.step()
+    policy.optimizer.zero_grad()
+    policy.lr_scheduler.step()
+
+    if policy.ema is not None:
+        policy.ema.step(policy.diffusion)
+
+    info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": policy.lr_scheduler.get_last_lr()[0],
+        "update_s": time.time() - start_time,
+    }
+
+    return info
+
+
+def update_act(policy, batch: dict[str, Tensor], **_) -> dict:
+    start_time = time.time()
+    policy.train()
+    batch = policy.normalize_inputs(batch)
+    loss_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = loss_dict["loss"]
+    loss.backward()
+
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
+    )
+
+    policy.optimizer.step()
+    policy.optimizer.zero_grad()
+
+    train_info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": policy.cfg.lr,
+        "update_s": time.time() - start_time,
+    }
+
+    return train_info
+
+
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
@@ -293,7 +356,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy.update(batch, step=step)
+        # Temporary hack to move update outside of policy
+        if isinstance(policy, DiffusionPolicy):
+            train_info = update_diffusion(policy, batch)
+        elif isinstance(policy, ActionChunkingTransformerPolicy):
+            train_info = update_act(policy, batch)
+
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
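
A possible follow-up, sketched here rather than included in the patch: the isinstance if/elif chain added to train.py has to grow by one branch for every new policy class. Assuming the two policy imports and the update_diffusion / update_act helpers introduced in the diff above, the dispatch could instead live in a small registry; POLICY_UPDATE_FNS and update_policy are hypothetical names for this sketch, not existing lerobot APIs.

    # Hypothetical sketch (not part of the patch): registry-based dispatch for policy updates.
    # Assumes DiffusionPolicy, ActionChunkingTransformerPolicy, update_diffusion and
    # update_act are importable/defined in lerobot/scripts/train.py as in the diff above.
    POLICY_UPDATE_FNS = {
        DiffusionPolicy: update_diffusion,
        ActionChunkingTransformerPolicy: update_act,
    }


    def update_policy(policy, batch):
        """Run one optimization step using the update function registered for this policy type."""
        for policy_cls, update_fn in POLICY_UPDATE_FNS.items():
            if isinstance(policy, policy_cls):
                return update_fn(policy, batch)
        raise NotImplementedError(f"No update function registered for {type(policy).__name__}")

With that in place, the training loop reduces to a single call such as train_info = update_policy(policy, batch), and supporting a new policy only requires registering its update function.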