"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
|
|
|
|
|
|
|
Once you have trained a model with this script, you can try to evaluate it using
examples/2_evaluate_pretrained_policy.py
"""

import os
from pathlib import Path

import torch
from omegaconf import OmegaConf
from tqdm import trange

from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.utils import init_hydra_config

output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)

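# Select the PushT environment and the diffusion policy via Hydra overrides,
# applied on top of lerobot/configs/default.yaml.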
overrides = [
    "env=pusht",
    "policy=diffusion",
    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
    "offline_steps=5000",
    "log_freq=250",
    "device=cuda",
]
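# Note: "device=cuda" assumes a CUDA-capable GPU. Overriding with "device=cpu"
# should also be accepted by the config, though training will be much slower.
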
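# Compose the default config with the overrides above.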
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
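
# Instantiate the policy from the config sub-trees passed below: noise
# scheduler, vision backbone (rgb_model), observation encoder, optimizer,
# and EMA settings.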
policy = DiffusionPolicy(
    cfg=cfg.policy,
    cfg_device=cfg.device,
    cfg_noise_scheduler=cfg.noise_scheduler,
    cfg_rgb_model=cfg.rgb_model,
    cfg_obs_encoder=cfg.obs_encoder,
    cfg_optimizer=cfg.optimizer,
    cfg_ema=cfg.ema,
    n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
    **cfg.policy,
)
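# Put the policy in training mode.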
policy.train()
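
# Create the offline buffer of demonstration data used for training.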
offline_buffer = make_offline_buffer(cfg)
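
# Offline training loop: each update() call samples a batch from the buffer
# and takes one optimization step; the returned training info is printed
# every log_freq steps.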
for offline_step in trange(cfg.offline_steps):
    train_info = policy.update(offline_buffer, offline_step)
    if offline_step % cfg.log_freq == 0:
        print(train_info)

# Save the policy, configuration, and normalization stats for later use.
policy.save_pretrained(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
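
# A minimal sketch of how the saved artifacts might be reloaded later. The
# config and stats round-trip with the same libraries used above; reloading
# the model weights depends on the policy API of your lerobot version, so
# check DiffusionPolicy for the matching load method:
#
#   cfg = OmegaConf.load(output_directory / "config.yaml")
#   stats = torch.load(output_directory / "stats.pth")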