move update outside policies

Quentin Gallouédec 2024-04-25 10:45:10 +02:00
parent 7626b9a4a3
commit dae901f556
3 changed files with 71 additions and 69 deletions
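
For context: this change lifts the gradient step out of each policy class's `update` method and into free functions in the training script, selected by policy type at the call site. Below is a minimal, self-contained sketch of that pattern, not lerobot code: `TinyPolicy`, the batch keys, the explicit `optimizer` argument, and the hyperparameters are illustrative assumptions (in the commit itself, the optimizer, config, and scheduler still live on the policy object).

import time

import torch
from torch import Tensor, nn


class TinyPolicy(nn.Module):
    """Illustrative stand-in for a policy whose forward() returns a dict holding a "loss"."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        pred = self.net(batch["observation.state"])
        return {"loss": nn.functional.mse_loss(pred, batch["action"])}


def update(policy: nn.Module, optimizer: torch.optim.Optimizer, batch: dict[str, Tensor]) -> dict:
    """One optimization step living outside the policy, mirroring update_act/update_diffusion."""
    start_time = time.time()
    policy.train()
    loss = policy(batch)["loss"]
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=10.0, error_if_nonfinite=False)
    optimizer.step()
    optimizer.zero_grad()
    return {"loss": loss.item(), "grad_norm": float(grad_norm), "update_s": time.time() - start_time}


policy = TinyPolicy()
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
batch = {"observation.state": torch.randn(8, 4), "action": torch.randn(8, 2)}
print(update(policy, optimizer, batch))

Keeping the optimization step outside nn.Module lets one training loop drive any policy whose forward() returns a loss dict; the isinstance dispatch added to train.py approximates that, and the diff itself labels it a temporary hack.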

View File

@@ -5,7 +5,6 @@ The majority of changes here involve removing unused code, unifying naming, and
 """
 import math
-import time
 from collections import deque
 from itertools import chain
 from typing import Callable
@@ -206,33 +205,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return loss_dict
-    def update(self, batch, **_) -> dict:
-        """Run the model in train mode, compute the loss, and do an optimization step."""
-        start_time = time.time()
-        self.train()
-        batch = self.normalize_inputs(batch)
-        loss_dict = self.forward(batch)
-        # TODO(rcadene): self.unnormalize_outputs(out_dict)
-        loss = loss_dict["loss"]
-        loss.backward()
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
-        )
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.cfg.lr,
-            "update_s": time.time() - start_time,
-        }
-        return info
     def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Stacks all the images in a batch and puts them in a new key: "observation.images".

View File

@@ -11,7 +11,6 @@ TODO(alexander-soare):
 import copy
 import logging
 import math
-import time
 from collections import deque
 from typing import Callable
@@ -155,41 +154,6 @@ class DiffusionPolicy(nn.Module):
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
-    def update(self, batch: dict[str, Tensor], **_) -> dict:
-        """Run the model in train mode, compute the loss, and do an optimization step."""
-        start_time = time.time()
-        self.diffusion.train()
-        batch = self.normalize_inputs(batch)
-        loss = self.forward(batch)["loss"]
-        loss.backward()
-        # TODO(rcadene): self.unnormalize_outputs(out_dict)
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.diffusion.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
-        )
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.lr_scheduler.step()
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.lr_scheduler.get_last_lr()[0],
-            "update_s": time.time() - start_time,
-        }
-        return info
     def save(self, fp):
         torch.save(self.state_dict(), fp)

View File

@@ -1,4 +1,5 @@
 import logging
+import time
 from copy import deepcopy
 from pathlib import Path
@@ -7,6 +8,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from torch import Tensor
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -22,6 +24,65 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
+def update_diffusion(policy, batch: dict[str, Tensor], **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
policy.diffusion.train()
batch = policy.normalize_inputs(batch)
loss = policy.forward(batch)["loss"]
loss.backward()
# TODO(rcadene): self.unnormalize_outputs(out_dict)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.diffusion.parameters(),
policy.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
policy.optimizer.step()
policy.optimizer.zero_grad()
policy.lr_scheduler.step()
if policy.ema is not None:
policy.ema.step(policy.diffusion)
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": policy.lr_scheduler.get_last_lr()[0],
"update_s": time.time() - start_time,
}
return info
+def update_act(policy, batch: dict[str, Tensor], **_) -> dict:
+    start_time = time.time()
+    policy.train()
+    batch = policy.normalize_inputs(batch)
+    loss_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = loss_dict["loss"]
+    loss.backward()
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
+    )
+    policy.optimizer.step()
+    policy.optimizer.zero_grad()
+    train_info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": policy.cfg.lr,
+        "update_s": time.time() - start_time,
+    }
+    return train_info
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
@@ -293,7 +354,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
-        train_info = policy.update(batch, step=step)
+        # Temporary hack to move update outside of policy
+        if isinstance(policy, DiffusionPolicy):
+            train_info = update_diffusion(policy, batch)
+        elif isinstance(policy, ActPolicy):
+            train_info = update_act(policy, batch)
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0: