Fix multiprocessing error on Windows by adding main guard

This commit is contained in:
Jakob Ganitzer 2024-05-31 20:02:02 +02:00
parent b32db2549c
commit c6e12f4d45
1 changed files with 63 additions and 57 deletions

View File

@ -12,6 +12,8 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
def main():
# Create a directory to store the training checkpoint. # Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_pusht_diffusion") output_directory = Path("outputs/train/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True) output_directory.mkdir(parents=True, exist_ok=True)
@ -49,7 +51,7 @@ optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
# Create dataloader for offline training. # Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=4, num_workers=0,
batch_size=64, batch_size=64,
shuffle=True, shuffle=True,
pin_memory=device != torch.device("cpu"), pin_memory=device != torch.device("cpu"),
@ -77,3 +79,7 @@ while not done:
# Save a policy checkpoint. # Save a policy checkpoint.
policy.save_pretrained(output_directory) policy.save_pretrained(output_directory)
if __name__ == "__main__":
main()