Remove `update` method from the policy (#99)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
parent
5b4fd8891d
commit
508bd92d03
|
@ -6,6 +6,7 @@ data
|
||||||
outputs
|
outputs
|
||||||
.vscode
|
.vscode
|
||||||
rl
|
rl
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
# HPC
|
# HPC
|
||||||
nautilus/*.yaml
|
nautilus/*.yaml
|
||||||
|
|
4
Makefile
4
Makefile
|
@ -22,8 +22,8 @@ test-end-to-end:
|
||||||
${MAKE} test-act-ete-eval
|
${MAKE} test-act-ete-eval
|
||||||
${MAKE} test-diffusion-ete-train
|
${MAKE} test-diffusion-ete-train
|
||||||
${MAKE} test-diffusion-ete-eval
|
${MAKE} test-diffusion-ete-eval
|
||||||
${MAKE} test-tdmpc-ete-train
|
# ${MAKE} test-tdmpc-ete-train
|
||||||
${MAKE} test-tdmpc-ete-eval
|
# ${MAKE} test-tdmpc-ete-eval
|
||||||
|
|
||||||
test-act-ete-train:
|
test-act-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
|
|
|
@ -38,6 +38,8 @@ policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, da
|
||||||
policy.train()
|
policy.train()
|
||||||
policy.to(device)
|
policy.to(device)
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay)
|
||||||
|
|
||||||
# Create dataloader for offline training.
|
# Create dataloader for offline training.
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -54,9 +56,14 @@ done = False
|
||||||
while not done:
|
while not done:
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||||
info = policy.update(batch)
|
output_dict = policy.forward(batch)
|
||||||
|
loss = output_dict["loss"]
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if step % log_freq == 0:
|
if step % log_freq == 0:
|
||||||
print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
|
print(f"step: {step} loss: {loss.item():.3f}")
|
||||||
step += 1
|
step += 1
|
||||||
if step >= training_steps:
|
if step >= training_steps:
|
||||||
done = True
|
done = True
|
||||||
|
|
|
@ -5,7 +5,6 @@ The majority of changes here involve removing unused code, unifying naming, and
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import time
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
@ -135,25 +134,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
|
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
|
||||||
|
|
||||||
self._reset_parameters()
|
self._reset_parameters()
|
||||||
self._create_optimizer()
|
|
||||||
|
|
||||||
def _create_optimizer(self):
|
|
||||||
optimizer_params_dicts = [
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
|
|
||||||
],
|
|
||||||
"lr": self.cfg.lr_backbone,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
self.optimizer = torch.optim.AdamW(
|
|
||||||
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reset_parameters(self):
|
def _reset_parameters(self):
|
||||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||||
|
@ -191,6 +171,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
|
|
||||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch = self.normalize_targets(batch)
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
||||||
|
|
||||||
l1_loss = (
|
l1_loss = (
|
||||||
|
@ -213,34 +195,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
|
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
def update(self, batch, **_) -> dict:
|
|
||||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
|
||||||
start_time = time.time()
|
|
||||||
self.train()
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
|
||||||
batch = self.normalize_targets(batch)
|
|
||||||
|
|
||||||
loss_dict = self.forward(batch)
|
|
||||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
|
||||||
loss = loss_dict["loss"]
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
|
||||||
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer.step()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"loss": loss.item(),
|
|
||||||
"grad_norm": float(grad_norm),
|
|
||||||
"lr": self.cfg.lr,
|
|
||||||
"update_s": time.time() - start_time,
|
|
||||||
}
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||||
|
|
|
@ -11,7 +11,6 @@ TODO(alexander-soare):
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import time
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
@ -19,7 +18,6 @@ import einops
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from robomimic.models.base_nets import SpatialSoftmax
|
from robomimic.models.base_nets import SpatialSoftmax
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
@ -74,26 +72,6 @@ class DiffusionPolicy(nn.Module):
|
||||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||||
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
||||||
|
|
||||||
# TODO(alexander-soare): Move optimizer out of policy.
|
|
||||||
self.optimizer = torch.optim.Adam(
|
|
||||||
self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(alexander-soare): Move LR scheduler out of policy.
|
|
||||||
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
|
|
||||||
self.global_step = 0
|
|
||||||
|
|
||||||
# configure lr scheduler
|
|
||||||
self.lr_scheduler = get_scheduler(
|
|
||||||
cfg.lr_scheduler,
|
|
||||||
optimizer=self.optimizer,
|
|
||||||
num_warmup_steps=cfg.lr_warmup_steps,
|
|
||||||
num_training_steps=lr_scheduler_num_training_steps,
|
|
||||||
# pytorch assumes stepping LRScheduler every epoch
|
|
||||||
# however huggingface diffusers steps it every batch
|
|
||||||
last_epoch=self.global_step - 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
Clear observation and action queues. Should be called on `env.reset()`
|
Clear observation and action queues. Should be called on `env.reset()`
|
||||||
|
@ -155,44 +133,10 @@ class DiffusionPolicy(nn.Module):
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
loss = self.diffusion.compute_loss(batch)
|
|
||||||
return {"loss": loss}
|
|
||||||
|
|
||||||
def update(self, batch: dict[str, Tensor], **_) -> dict:
|
|
||||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
self.diffusion.train()
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
|
loss = self.diffusion.compute_loss(batch)
|
||||||
loss = self.forward(batch)["loss"]
|
return {"loss": loss}
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
|
||||||
self.diffusion.parameters(),
|
|
||||||
self.cfg.grad_clip_norm,
|
|
||||||
error_if_nonfinite=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer.step()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
self.lr_scheduler.step()
|
|
||||||
|
|
||||||
if self.ema is not None:
|
|
||||||
self.ema.step(self.diffusion)
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"loss": loss.item(),
|
|
||||||
"grad_norm": float(grad_norm),
|
|
||||||
"lr": self.lr_scheduler.get_last_lr()[0],
|
|
||||||
"update_s": time.time() - start_time,
|
|
||||||
}
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
def save(self, fp):
|
def save(self, fp):
|
||||||
torch.save(self.state_dict(), fp)
|
torch.save(self.state_dict(), fp)
|
||||||
|
|
|
@ -36,10 +36,3 @@ class Policy(Protocol):
|
||||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||||
with caching.
|
with caching.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def update(self, batch):
|
|
||||||
"""Does compute_loss then an optimization step.
|
|
||||||
|
|
||||||
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
|
|
||||||
disappear.
|
|
||||||
"""
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -7,6 +8,7 @@ import hydra
|
||||||
import torch
|
import torch
|
||||||
from datasets import concatenate_datasets
|
from datasets import concatenate_datasets
|
||||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
from datasets.utils import disable_progress_bars, enable_progress_bars
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
|
@ -22,6 +24,37 @@ from lerobot.common.utils.utils import (
|
||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
|
|
||||||
|
|
||||||
|
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
||||||
|
start_time = time.time()
|
||||||
|
policy.train()
|
||||||
|
output_dict = policy.forward(batch)
|
||||||
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
|
loss = output_dict["loss"]
|
||||||
|
loss.backward()
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
policy.parameters(),
|
||||||
|
grad_clip_norm,
|
||||||
|
error_if_nonfinite=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
if lr_scheduler is not None:
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
if hasattr(policy, "ema") and policy.ema is not None:
|
||||||
|
policy.ema.step(policy.diffusion)
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"loss": loss.item(),
|
||||||
|
"grad_norm": float(grad_norm),
|
||||||
|
"lr": optimizer.param_groups[0]['lr'],
|
||||||
|
"update_s": time.time() - start_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
def train_cli(cfg: dict):
|
def train_cli(cfg: dict):
|
||||||
train(
|
train(
|
||||||
|
@ -234,6 +267,36 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
logging.info("make_policy")
|
logging.info("make_policy")
|
||||||
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
||||||
|
|
||||||
|
# Create optimizer and scheduler
|
||||||
|
# Temporary hack to move optimizer out of policy
|
||||||
|
if cfg.policy.name == "act":
|
||||||
|
optimizer_params_dicts = [
|
||||||
|
{"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
|
||||||
|
{
|
||||||
|
"params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
|
||||||
|
"lr": cfg.policy.lr_backbone,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.policy.lr, weight_decay=cfg.policy.weight_decay)
|
||||||
|
lr_scheduler = None
|
||||||
|
elif cfg.policy.name == "diffusion":
|
||||||
|
optimizer = torch.optim.Adam(
|
||||||
|
policy.diffusion.parameters(), cfg.policy.lr, cfg.policy.adam_betas, cfg.policy.adam_eps, cfg.policy.adam_weight_decay
|
||||||
|
)
|
||||||
|
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
|
||||||
|
# configure lr scheduler
|
||||||
|
lr_scheduler = get_scheduler(
|
||||||
|
cfg.policy.lr_scheduler,
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=cfg.policy.lr_warmup_steps,
|
||||||
|
num_training_steps=cfg.offline_steps,
|
||||||
|
# pytorch assumes stepping LRScheduler every epoch
|
||||||
|
# however huggingface diffusers steps it every batch
|
||||||
|
last_epoch=-1,
|
||||||
|
)
|
||||||
|
elif policy.name == "tdmpc":
|
||||||
|
raise NotImplementedError("TD-MPC not implemented yet.")
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
|
||||||
|
@ -292,7 +355,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||||
|
|
||||||
train_info = policy.update(batch, step=step)
|
train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)
|
||||||
|
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.log_freq == 0:
|
if step % cfg.log_freq == 0:
|
||||||
|
@ -358,7 +421,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||||
|
|
||||||
train_info = policy.update(batch, step)
|
train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)
|
||||||
|
|
||||||
if step % cfg.log_freq == 0:
|
if step % cfg.log_freq == 0:
|
||||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
|
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
|
||||||
|
|
|
@ -18,8 +18,8 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name,policy_name,extra_overrides",
|
"env_name,policy_name,extra_overrides",
|
||||||
[
|
[
|
||||||
("xarm", "tdmpc", ["policy.mpc=true"]),
|
# ("xarm", "tdmpc", ["policy.mpc=true"]),
|
||||||
("pusht", "tdmpc", ["policy.mpc=false"]),
|
# ("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||||
("pusht", "diffusion", []),
|
("pusht", "diffusion", []),
|
||||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
|
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||||
(
|
(
|
||||||
|
@ -86,7 +86,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||||
|
|
||||||
# Test updating the policy
|
# Test updating the policy
|
||||||
policy.update(batch, step=0)
|
policy.forward(batch, step=0)
|
||||||
|
|
||||||
# reset the policy and environment
|
# reset the policy and environment
|
||||||
policy.reset()
|
policy.reset()
|
||||||
|
|
Loading…
Reference in New Issue