lerobot/examples/3_train_policy.py

"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""

import os
from pathlib import Path

import torch
from omegaconf import OmegaConf

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.utils import init_hydra_config

output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)

# Number of offline training steps (we'll only do offline training for this example).
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 5000
device = torch.device("cuda")  # Switch to "cpu" if you don't have a GPU (training will be much slower).
log_freq = 250

# Set up the dataset.
hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
dataset = make_dataset(hydra_cfg)
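# Optional sanity check (a hedged sketch: the exact keys depend on the dataset your version of
# lerobot builds, but the training loop below assumes each sample is a dict of tensors):
#   print(len(dataset))       # number of samples available for offline training
#   print(dataset[0].keys())  # observation/action tensors keyed by name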

# Set up the policy.
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
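# If you do want different behavior, overrides go through the constructor, e.g. (a hedged sketch;
# check configuration_diffusion.py for the field names actually available in your version):
#   cfg = DiffusionConfig(n_obs_steps=2, horizon=16)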
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
policy.train()
policy.to(device)

# Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=cfg.batch_size,
    shuffle=True,
    # Pin host memory only when training on an accelerator, to speed up host-to-device copies.
    pin_memory=device != torch.device("cpu"),
    drop_last=True,
)

# Run training loop.
step = 0
done = False
while not done:
    for batch in dataloader:
        # Move every tensor in the batch to the training device.
        for k in batch:
            batch[k] = batch[k].to(device, non_blocking=True)
        # The forward pass returns a dict of training info (loss, update time in seconds, etc.).
        info = policy(batch)
        if step % log_freq == 0:
            num_samples = (step + 1) * cfg.batch_size
            loss = info["loss"]
            update_s = info["update_s"]
            print(
                f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)"
            )
        step += 1
        if step >= training_steps:
            done = True
            break

# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
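
# For later reuse (see examples/2_evaluate_pretrained_policy.py), the artifacts saved above can be
# loaded back. A minimal sketch, assuming `policy.save` above writes the module's state_dict
# (verify against your version of `modeling_diffusion.py` before relying on it):
#   reloaded_cfg = OmegaConf.load(output_directory / "config.yaml")
#   stats = torch.load(output_directory / "stats.pth")
#   reloaded_policy = DiffusionPolicy(DiffusionConfig(), lr_scheduler_num_training_steps=training_steps)
#   reloaded_policy.load_state_dict(torch.load(output_directory / "model.pt"))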