Fix multiprocessing error on Windows by adding main guard

Jakob Ganitzer 2024-05-31 20:02:02 +02:00
parent b32db2549c
commit c6e12f4d45
1 changed file with 63 additions and 57 deletions
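Background for the fix: on Windows (and macOS since Python 3.8), multiprocessing uses the "spawn" start method, so each DataLoader worker process re-imports the training script as a module rather than inheriting a forked copy of it. Any code at module top level then runs again in every worker, and PyTorch aborts with the familiar "An attempt has been made to start a new process before the current process has finished its bootstrapping phase" RuntimeError. The standard remedy, applied by this commit, is to move the script body into a function and call it behind a main guard. A minimal standalone sketch of the pattern (a toy example, not the LeRobot script itself):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    # All top-level work lives inside a function, so re-importing this
    # module (which spawned workers do on Windows) has no side effects.
    dataset = TensorDataset(torch.arange(8.0).unsqueeze(1))
    loader = DataLoader(dataset, batch_size=2, num_workers=2)
    for (batch,) in loader:
        print(batch.squeeze(1).tolist())


if __name__ == "__main__":
    # Only the directly launched process enters here; spawned workers
    # re-import the module but skip this block, which prevents the
    # recursive process creation that crashes on Windows.
    main()
```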


@@ -12,18 +12,20 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
-# Create a directory to store the training checkpoint.
-output_directory = Path("outputs/train/example_pusht_diffusion")
-output_directory.mkdir(parents=True, exist_ok=True)
-
-# Number of offline training steps (we'll only do offline training for this example.)
-# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
-training_steps = 5000
-device = torch.device("cuda")
-log_freq = 250
-
-# Set up the dataset.
-delta_timestamps = {
-    # Load the previous image and state at -0.1 seconds before current frame,
-    # then load current image and state corresponding to 0.0 second.
-    "observation.image": [-0.1, 0.0],
+
+def main():
+    # Create a directory to store the training checkpoint.
+    output_directory = Path("outputs/train/example_pusht_diffusion")
+    output_directory.mkdir(parents=True, exist_ok=True)
+
+    # Number of offline training steps (we'll only do offline training for this example.)
+    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
+    training_steps = 5000
+    device = torch.device("cuda")
+    log_freq = 250
+
+    # Set up the dataset.
+    delta_timestamps = {
+        # Load the previous image and state at -0.1 seconds before current frame,
+        # then load current image and state corresponding to 0.0 second.
+        "observation.image": [-0.1, 0.0],
@@ -32,34 +34,34 @@ delta_timestamps = {
-    # and 14 future actions with a 0.1 seconds spacing. All these actions will be
-    # used to supervise the policy.
-    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
-}
-dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
-
-# Set up the policy.
-# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
-# For this example, no arguments need to be passed because the defaults are set up for PushT.
-# If you're doing something different, you will likely need to change at least some of the defaults.
-cfg = DiffusionConfig()
-policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
-policy.train()
-policy.to(device)
-
-optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
-
-# Create dataloader for offline training.
-dataloader = torch.utils.data.DataLoader(
-    dataset,
-    num_workers=4,
-    batch_size=64,
-    shuffle=True,
-    pin_memory=device != torch.device("cpu"),
-    drop_last=True,
-)
-
-# Run training loop.
-step = 0
-done = False
-while not done:
-    for batch in dataloader:
-        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        output_dict = policy.forward(batch)
+        # and 14 future actions with a 0.1 seconds spacing. All these actions will be
+        # used to supervise the policy.
+        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
+    }
+    dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
+
+    # Set up the policy.
+    # Policies are initialized with a configuration class, in this case `DiffusionConfig`.
+    # For this example, no arguments need to be passed because the defaults are set up for PushT.
+    # If you're doing something different, you will likely need to change at least some of the defaults.
+    cfg = DiffusionConfig()
+    policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
+    policy.train()
+    policy.to(device)
+
+    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
+
+    # Create dataloader for offline training.
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=64,
+        shuffle=True,
+        pin_memory=device != torch.device("cpu"),
+        drop_last=True,
+    )
+
+    # Run training loop.
+    step = 0
+    done = False
+    while not done:
+        for batch in dataloader:
+            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+            output_dict = policy.forward(batch)
@@ -75,5 +77,9 @@ while not done:
-            done = True
-            break
-
-# Save a policy checkpoint.
-policy.save_pretrained(output_directory)
+                done = True
+                break
+
+    # Save a policy checkpoint.
+    policy.save_pretrained(output_directory)
+
+
+if __name__ == "__main__":
+    main()
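A note on the DataLoader change: besides adding the guard, the commit drops num_workers from 4 to 0, which disables worker processes entirely and loads batches in the main process. With the guard in place, worker processes also work on Windows, though spawn makes their startup slower than fork. A hedged sketch of that alternative, not part of this commit:

```python
# Hypothetical variant of the commit's DataLoader: keep worker processes
# and amortize the cost of spawning them by reusing workers across epochs.
# Safe on Windows only because the script now has a main guard.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    persistent_workers=True,  # requires num_workers > 0
    batch_size=64,
    shuffle=True,
    pin_memory=device != torch.device("cpu"),
    drop_last=True,
)
```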