move update outside policies

Quentin Gallouédec 2024-04-25 10:45:10 +02:00
parent 7626b9a4a3
commit dae901f556
3 changed files with 71 additions and 69 deletions

View File

@@ -5,7 +5,6 @@ The majority of changes here involve removing unused code, unifying naming, and
 """
 import math
-import time
 from collections import deque
 from itertools import chain
 from typing import Callable
@@ -206,33 +205,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return loss_dict
 
-    def update(self, batch, **_) -> dict:
-        """Run the model in train mode, compute the loss, and do an optimization step."""
-        start_time = time.time()
-        self.train()
-
-        batch = self.normalize_inputs(batch)
-        loss_dict = self.forward(batch)
-        # TODO(rcadene): self.unnormalize_outputs(out_dict)
-        loss = loss_dict["loss"]
-        loss.backward()
-
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
-        )
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.cfg.lr,
-            "update_s": time.time() - start_time,
-        }
-
-        return info
-
     def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Stacks all the images in a batch and puts them in a new key: "observation.images".

View File

@@ -11,7 +11,6 @@ TODO(alexander-soare):
 import copy
 import logging
 import math
-import time
 from collections import deque
 from typing import Callable
@@ -155,41 +154,6 @@ class DiffusionPolicy(nn.Module):
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
 
-    def update(self, batch: dict[str, Tensor], **_) -> dict:
-        """Run the model in train mode, compute the loss, and do an optimization step."""
-        start_time = time.time()
-        self.diffusion.train()
-
-        batch = self.normalize_inputs(batch)
-        loss = self.forward(batch)["loss"]
-        loss.backward()
-
-        # TODO(rcadene): self.unnormalize_outputs(out_dict)
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.diffusion.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
-        )
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
-
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.lr_scheduler.get_last_lr()[0],
-            "update_s": time.time() - start_time,
-        }
-
-        return info
-
     def save(self, fp):
         torch.save(self.state_dict(), fp)
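
Both removed update() methods follow the same clip-then-step pattern that the new train-script functions reuse. A minimal, self-contained sketch of that pattern, using a placeholder model, optimizer, and max_norm (not names from the lerobot codebase):

import torch
from torch import nn

# Placeholder model and optimizer, only to make the sketch runnable.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
max_norm = 10.0  # stands in for cfg.grad_clip_norm

x, target = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), target)
loss.backward()

# clip_grad_norm_ returns the total gradient norm before clipping,
# which is what both policies report as "grad_norm".
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, error_if_nonfinite=False)
optimizer.step()
optimizer.zero_grad()
info = {"loss": loss.item(), "grad_norm": float(grad_norm)}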

View File

@@ -1,4 +1,5 @@
 import logging
+import time
 from copy import deepcopy
 from pathlib import Path
@@ -7,6 +8,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from torch import Tensor
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -14,14 +16,73 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_logging,
     set_global_seed,
 )
 from lerobot.scripts.eval import eval_policy
+
+
+def update_diffusion(policy, batch: dict[str, Tensor], **_) -> dict:
+    """Run the model in train mode, compute the loss, and do an optimization step."""
+    start_time = time.time()
+    policy.diffusion.train()
+
+    batch = policy.normalize_inputs(batch)
+    loss = policy.forward(batch)["loss"]
+    loss.backward()
+
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.diffusion.parameters(),
+        policy.cfg.grad_clip_norm,
+        error_if_nonfinite=False,
+    )
+
+    policy.optimizer.step()
+    policy.optimizer.zero_grad()
+    policy.lr_scheduler.step()
+
+    if policy.ema is not None:
+        policy.ema.step(policy.diffusion)
+
+    info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": policy.lr_scheduler.get_last_lr()[0],
+        "update_s": time.time() - start_time,
+    }
+
+    return info
+
+
+def update_act(policy, batch: dict[str, Tensor], **_) -> dict:
+    """Run the model in train mode, compute the loss, and do an optimization step."""
+    start_time = time.time()
+    policy.train()
+
+    batch = policy.normalize_inputs(batch)
+    loss_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = loss_dict["loss"]
+    loss.backward()
+
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
+    )
+
+    policy.optimizer.step()
+    policy.optimizer.zero_grad()
+
+    train_info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": policy.cfg.lr,
+        "update_s": time.time() - start_time,
+    }
+
+    return train_info
+
+
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
@@ -293,7 +354,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy.update(batch, step=step)
+        # Temporary hack to move update outside of policy
+        if isinstance(policy, DiffusionPolicy):
+            train_info = update_diffusion(policy, batch)
+        elif isinstance(policy, ActPolicy):
+            train_info = update_act(policy, batch)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
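
The isinstance chain above is labeled a temporary hack. One possible follow-up (not part of this commit; the names UPDATE_FNS and update_policy are hypothetical) is a dispatch table in train.py keyed by policy class, so each new policy only registers its update function:

# Hypothetical sketch, assuming DiffusionPolicy, ActPolicy, update_diffusion and
# update_act are importable in train.py as in the hunks above.
UPDATE_FNS = {
    DiffusionPolicy: update_diffusion,
    ActPolicy: update_act,
}

def update_policy(policy, batch):
    # Find the first registered class the policy is an instance of and apply its update.
    for policy_cls, update_fn in UPDATE_FNS.items():
        if isinstance(policy, policy_cls):
            return update_fn(policy, batch)
    raise NotImplementedError(f"No update function registered for {type(policy).__name__}")

# The call site in the training loop would then become:
# train_info = update_policy(policy, batch)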