move update outside policies
parent 7626b9a4a3
commit dae901f556

@@ -5,7 +5,6 @@ The majority of changes here involve removing unused code, unifying naming, and
 """
 
 import math
-import time
 from collections import deque
 from itertools import chain
 from typing import Callable
@@ -206,33 +205,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
         return loss_dict
 
-    def update(self, batch, **_) -> dict:
-        """Run the model in train mode, compute the loss, and do an optimization step."""
-        start_time = time.time()
-        self.train()
-
-        batch = self.normalize_inputs(batch)
-
-        loss_dict = self.forward(batch)
-        # TODO(rcadene): self.unnormalize_outputs(out_dict)
-        loss = loss_dict["loss"]
-        loss.backward()
-
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
-        )
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.cfg.lr,
-            "update_s": time.time() - start_time,
-        }
-
-        return info
 
     def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Stacks all the images in a batch and puts them in a new key: "observation.images".
@@ -11,7 +11,6 @@ TODO(alexander-soare):
 import copy
 import logging
 import math
-import time
 from collections import deque
 from typing import Callable
 
@@ -155,41 +154,6 @@ class DiffusionPolicy(nn.Module):
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
 
-    def update(self, batch: dict[str, Tensor], **_) -> dict:
-        """Run the model in train mode, compute the loss, and do an optimization step."""
-        start_time = time.time()
-
-        self.diffusion.train()
-
-        batch = self.normalize_inputs(batch)
-
-        loss = self.forward(batch)["loss"]
-        loss.backward()
-
-        # TODO(rcadene): self.unnormalize_outputs(out_dict)
-
-        grad_norm = torch.nn.utils.clip_grad_norm_(
-            self.diffusion.parameters(),
-            self.cfg.grad_clip_norm,
-            error_if_nonfinite=False,
-        )
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-        if self.ema is not None:
-            self.ema.step(self.diffusion)
-
-        info = {
-            "loss": loss.item(),
-            "grad_norm": float(grad_norm),
-            "lr": self.lr_scheduler.get_last_lr()[0],
-            "update_s": time.time() - start_time,
-        }
-
-        return info
-
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
@@ -1,4 +1,5 @@
 import logging
+import time
 from copy import deepcopy
 from pathlib import Path
 
@@ -7,6 +8,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from torch import Tensor
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -22,6 +24,65 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
+def update_diffusion(policy, batch: dict[str, Tensor], **_) -> dict:
+    """Run the model in train mode, compute the loss, and do an optimization step."""
+    start_time = time.time()
+    policy.diffusion.train()
+    batch = policy.normalize_inputs(batch)
+    loss = policy.forward(batch)["loss"]
+    loss.backward()
+
+    # TODO(rcadene): self.unnormalize_outputs(out_dict)
+
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.diffusion.parameters(),
+        policy.cfg.grad_clip_norm,
+        error_if_nonfinite=False,
+    )
+
+    policy.optimizer.step()
+    policy.optimizer.zero_grad()
+    policy.lr_scheduler.step()
+
+    if policy.ema is not None:
+        policy.ema.step(policy.diffusion)
+
+    info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": policy.lr_scheduler.get_last_lr()[0],
+        "update_s": time.time() - start_time,
+    }
+
+    return info
+
+
+def update_act(policy, batch: dict[str, Tensor], **_) -> dict:
+    start_time = time.time()
+    policy.train()
+    batch = policy.normalize_inputs(batch)
+    loss_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = loss_dict["loss"]
+    loss.backward()
+
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
+    )
+
+    policy.optimizer.step()
+    policy.optimizer.zero_grad()
+
+    train_info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": policy.cfg.lr,
+        "update_s": time.time() - start_time,
+    }
+
+    return train_info
+
+
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
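Taken together, update_diffusion and update_act differ mainly in which parameter set is clipped and whether a learning-rate scheduler and EMA are stepped. As a rough sketch of where this refactor could head later (this is not part of the commit), the two could be folded into one helper. The attribute names (normalize_inputs, optimizer, cfg.grad_clip_norm, lr_scheduler, ema, diffusion) come from the diff above; the getattr fallbacks and the use of policy.parameters() for clipping in both cases are assumptions made for this sketch.

import time

import torch


def update_policy(policy, batch: dict) -> dict:
    """Run one training step on `batch` and return logging info (sketch only)."""
    start_time = time.time()
    policy.train()
    batch = policy.normalize_inputs(batch)

    loss = policy.forward(batch)["loss"]
    loss.backward()

    # Clip gradients on the whole policy; the diffusion variant above clips
    # only policy.diffusion.parameters(), a detail this sketch glosses over.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
    )

    policy.optimizer.step()
    policy.optimizer.zero_grad()

    # Only the diffusion policy has an LR scheduler and EMA in this diff.
    lr_scheduler = getattr(policy, "lr_scheduler", None)
    if lr_scheduler is not None:
        lr_scheduler.step()
    ema = getattr(policy, "ema", None)
    if ema is not None:
        ema.step(policy.diffusion)

    return {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "update_s": time.time() - start_time,
    }
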
@@ -293,7 +354,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy.update(batch, step=step)
+        # Temporary hack to move update outside of policy
+        if isinstance(policy, DiffusionPolicy):
+            train_info = update_diffusion(policy, batch)
+        elif isinstance(policy, ActPolicy):
+            train_info = update_act(policy, batch)
+
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
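The isinstance chain in the last hunk is flagged in the diff itself as a temporary hack. One hypothetical way to keep the dispatch out of the loop body, sketched under the assumption that the policy classes and update functions are in scope in train.py exactly as named above (none of this is in the commit):

# Hypothetical registry mapping policy classes to their update functions.
# DiffusionPolicy, ActPolicy, update_diffusion and update_act are the names
# used in train.py in the hunks above; the registry itself is not part of
# this commit.
UPDATE_FUNCTIONS = {
    DiffusionPolicy: update_diffusion,
    ActPolicy: update_act,
}


def run_update(policy, batch) -> dict:
    """Dispatch to the update function registered for this policy type."""
    for policy_cls, update_fn in UPDATE_FUNCTIONS.items():
        if isinstance(policy, policy_cls):
            return update_fn(policy, batch)
    raise NotImplementedError(f"No update function registered for {type(policy).__name__}")

Adding a new policy would then mean registering one entry instead of extending the if/elif chain in the training loop.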