"""This script demonstrates how to train Diffusion Policy on the PushT environment.

Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""

import os
from pathlib import Path

import torch
from omegaconf import OmegaConf

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.utils.utils import init_hydra_config

output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)

# Number of offline training steps (we'll only do offline training for this example).
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 5000
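# If no CUDA-capable GPU is available, `torch.device("cpu")` also works, just much more slowly.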
device = torch.device("cuda")
log_freq = 250

# Set up the dataset.
hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
dataset = make_dataset(hydra_cfg)
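# The dataset is a regular PyTorch dataset. The print below is purely illustrative (it is
# not needed for training) and just shows how many samples are available for offline training.
print(f"Number of samples in the dataset: {len(dataset)}")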

# Set up the policy.
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
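# Purely as an illustration of overriding defaults (not needed for PushT): you could pass
# arguments to the config instead. The field names below are assumptions based on typical
# `DiffusionConfig` fields, so check the configuration class before relying on them.
# cfg = DiffusionConfig(n_obs_steps=2, horizon=16, n_action_steps=8)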
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
policy.train()
policy.to(device)

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)

# Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=64,
    shuffle=True,
    pin_memory=device != torch.device("cpu"),
    drop_last=True,
)
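
# Purely illustrative: peek at one batch to see what the dataset provides. Every value is a
# tensor (which is what the training loop below assumes), but the exact keys depend on the
# dataset configuration, so printing them once is a handy sanity check.
example_batch = next(iter(dataloader))
print({k: tuple(v.shape) for k, v in example_batch.items()})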

# Run training loop.
step = 0
done = False
while not done:
    for batch in dataloader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break
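
# Optional and illustrative only: for longer runs you could also checkpoint periodically
# inside the loop, e.g. `if step % 1000 == 0: policy.save(output_directory / f"model_{step}.pt")`,
# reusing the same `policy.save` call as below.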

# Save the policy and configuration for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
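
# Illustrative sketch only, assuming `policy.save` writes the policy's state dict with
# `torch.save`: reloading for evaluation would then look roughly like the lines below.
# See examples/2_evaluate_pretrained_policy.py for the canonical way to evaluate.
#
#     reloaded_policy = DiffusionPolicy(
#         cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats
#     )
#     reloaded_policy.load_state_dict(torch.load(output_directory / "model.pt"))
#     reloaded_policy.eval()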