This commit is contained in:
Cadene 2024-05-06 00:49:33 +00:00
parent cf4b4c5a18
commit 764a7ad2c3
3 changed files with 3 additions and 3 deletions

View File

@ -80,7 +80,7 @@ print(f"{dataset[0]['action'].shape=}\n") # (64,c)
# PyTorch datasets. # PyTorch datasets.
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=1, num_workers=0,
batch_size=32, batch_size=32,
shuffle=True, shuffle=True,
) )

View File

@ -49,7 +49,7 @@ optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
# Create dataloader for offline training. # Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=4, num_workers=0,
batch_size=64, batch_size=64,
shuffle=True, shuffle=True,
pin_memory=device != torch.device("cpu"), pin_memory=device != torch.device("cpu"),

View File

@ -37,7 +37,7 @@ def test_examples_3_and_2():
file_contents, file_contents,
[ [
("training_steps = 5000", "training_steps = 1"), ("training_steps = 5000", "training_steps = 1"),
("num_workers=4", "num_workers=0"), # ("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'), ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=64", "batch_size=1"), ("batch_size=64", "batch_size=1"),
], ],