Fix multiprocessing error on Windows by adding main guard
parent b32db2549c
commit c6e12f4d45
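Why the guard is needed: on Windows, Python's multiprocessing starts DataLoader worker processes with the spawn method, which re-imports the training script in every worker. Without an if __name__ == "__main__" guard, each worker re-executes the top-level training code, and the run aborts with multiprocessing's bootstrapping RuntimeError. The commit therefore wraps the script body in a main() function, calls it from a guarded entry point, and also sets num_workers=0 so the example loads data in the main process.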
@@ -12,6 +12,8 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
+
+def main():
     # Create a directory to store the training checkpoint.
     output_directory = Path("outputs/train/example_pusht_diffusion")
     output_directory.mkdir(parents=True, exist_ok=True)
@@ -49,7 +51,7 @@ optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
     # Create dataloader for offline training.
     dataloader = torch.utils.data.DataLoader(
         dataset,
-        num_workers=4,
+        num_workers=0,
         batch_size=64,
         shuffle=True,
         pin_memory=device != torch.device("cpu"),
@@ -77,3 +79,7 @@ while not done:
 
     # Save a policy checkpoint.
     policy.save_pretrained(output_directory)
+
+
+if __name__ == "__main__":
+    main()
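For reference, a minimal self-contained sketch of the pattern this commit applies; the dataset, tensor shapes, and num_workers value below are placeholders for illustration, not taken from the repository:

import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    # All top-level work lives inside a function, so importing this module
    # (which spawn-based worker processes do on Windows) has no side effects.
    dataset = TensorDataset(torch.randn(256, 3), torch.randn(256, 1))
    dataloader = DataLoader(
        dataset,
        num_workers=2,  # >0 spawns worker processes and requires the guard below
        batch_size=64,
        shuffle=True,
    )
    for _batch in dataloader:  # iterate once to exercise the workers
        pass


if __name__ == "__main__":
    # Only the launching process runs main(); re-imported worker copies of the
    # module skip this block, avoiding the Windows bootstrapping RuntimeError.
    main()

Note that the commit also changes num_workers from 4 to 0, which keeps data loading in the main process and never spawns workers at all; either change on its own should be enough to avoid the crash, but the guard remains good practice for any script that may later enable workers.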