From c6e12f4d45e051a6a9073a3a685e68f361e0392e Mon Sep 17 00:00:00 2001
From: Jakob Ganitzer
Date: Fri, 31 May 2024 20:02:02 +0200
Subject: [PATCH] Fix multiprocessing error on Windows by adding main guard

---
 examples/3_train_policy.py | 120 +++++++++++++++++++------------------
 1 file changed, 63 insertions(+), 57 deletions(-)

diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index c5ce0d18..d2696c18 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -12,68 +12,74 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
-# Create a directory to store the training checkpoint.
-output_directory = Path("outputs/train/example_pusht_diffusion")
-output_directory.mkdir(parents=True, exist_ok=True)
 
-# Number of offline training steps (we'll only do offline training for this example.)
-# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
-training_steps = 5000
-device = torch.device("cuda")
-log_freq = 250
+def main():
+    # Create a directory to store the training checkpoint.
+    output_directory = Path("outputs/train/example_pusht_diffusion")
+    output_directory.mkdir(parents=True, exist_ok=True)
 
-# Set up the dataset.
-delta_timestamps = {
-    # Load the previous image and state at -0.1 seconds before current frame,
-    # then load current image and state corresponding to 0.0 second.
-    "observation.image": [-0.1, 0.0],
-    "observation.state": [-0.1, 0.0],
-    # Load the previous action (-0.1), the next action to be executed (0.0),
-    # and 14 future actions with a 0.1 seconds spacing. All these actions will be
-    # used to supervise the policy.
-    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
-}
-dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
+    # Number of offline training steps (we'll only do offline training for this example.)
+    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
+    training_steps = 5000
+    device = torch.device("cuda")
+    log_freq = 250
 
-# Set up the the policy.
-# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
-# For this example, no arguments need to be passed because the defaults are set up for PushT.
-# If you're doing something different, you will likely need to change at least some of the defaults.
-cfg = DiffusionConfig()
-policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
-policy.train()
-policy.to(device)
+    # Set up the dataset.
+    delta_timestamps = {
+        # Load the previous image and state at -0.1 seconds before current frame,
+        # then load current image and state corresponding to 0.0 second.
+        "observation.image": [-0.1, 0.0],
+        "observation.state": [-0.1, 0.0],
+        # Load the previous action (-0.1), the next action to be executed (0.0),
+        # and 14 future actions with a 0.1 seconds spacing. All these actions will be
+        # used to supervise the policy.
+        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
+    }
+    dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
 
-optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
+    # Set up the policy.
+    # Policies are initialized with a configuration class, in this case `DiffusionConfig`.
+    # For this example, no arguments need to be passed because the defaults are set up for PushT.
+    # If you're doing something different, you will likely need to change at least some of the defaults.
+    cfg = DiffusionConfig()
+    policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
+    policy.train()
+    policy.to(device)
 
-# Create dataloader for offline training.
-dataloader = torch.utils.data.DataLoader(
-    dataset,
-    num_workers=4,
-    batch_size=64,
-    shuffle=True,
-    pin_memory=device != torch.device("cpu"),
-    drop_last=True,
-)
+    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
 
-# Run training loop.
-step = 0
-done = False
-while not done:
-    for batch in dataloader:
-        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        output_dict = policy.forward(batch)
-        loss = output_dict["loss"]
-        loss.backward()
-        optimizer.step()
-        optimizer.zero_grad()
+    # Create dataloader for offline training.
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=64,
+        shuffle=True,
+        pin_memory=device != torch.device("cpu"),
+        drop_last=True,
+    )
 
-        if step % log_freq == 0:
-            print(f"step: {step} loss: {loss.item():.3f}")
-        step += 1
-        if step >= training_steps:
-            done = True
-            break
+    # Run training loop.
+    step = 0
+    done = False
+    while not done:
+        for batch in dataloader:
+            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+            output_dict = policy.forward(batch)
+            loss = output_dict["loss"]
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
 
-# Save a policy checkpoint.
-policy.save_pretrained(output_directory)
+            if step % log_freq == 0:
+                print(f"step: {step} loss: {loss.item():.3f}")
+            step += 1
+            if step >= training_steps:
+                done = True
+                break
+
+    # Save a policy checkpoint.
+    policy.save_pretrained(output_directory)
+
+
+if __name__ == "__main__":
+    main()
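
Note (not part of the patch): on Windows, multiprocessing uses the "spawn" start method, so every DataLoader worker process re-imports the training script. When the training code lives at module level, each spawned worker executes it again and typically fails with multiprocessing's bootstrapping RuntimeError; moving the code into main() and calling it only under the `if __name__ == "__main__":` guard keeps that code in the parent process. The following is a minimal, self-contained sketch of the same pattern, for illustration only: the TensorDataset, batch size, and worker count are placeholders, not values from the LeRobot example.

import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    # Placeholder dataset standing in for LeRobotDataset: 100 scalar samples.
    dataset = TensorDataset(torch.arange(100, dtype=torch.float32))

    # num_workers > 0 starts worker processes. On Windows they are spawned and
    # re-import this module, which is safe only because this loading code sits
    # inside main() behind the guard below.
    dataloader = DataLoader(dataset, batch_size=10, num_workers=2, shuffle=True)

    for (batch,) in dataloader:
        print(batch.shape)


if __name__ == "__main__":
    # Runs only in the parent process; spawned workers import the module
    # without re-executing the data-loading loop.
    main()

With the guard in place, a DataLoader can safely use num_workers > 0 on Windows.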