Fix test_examples

This commit is contained in:
Alexander Soare 2024-04-16 14:07:16 +01:00
parent 9c2f10bd04
commit 43a614c173
2 changed files with 4 additions and 2 deletions

View File

@ -53,7 +53,8 @@ step = 0
done = False
while not done:
for batch in dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
for k in batch:
batch[k] = batch[k].to(device, non_blocking=True)
info = policy(batch)
if step % log_freq == 0:
num_samples = (step + 1) * cfg.batch_size

View File

@ -29,13 +29,14 @@ def test_examples_3_and_2():
with open(path, "r") as file:
file_contents = file.read()
# Do less steps, use CPU, and don't complicate things with dataloader workers.
# Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
file_contents = _find_and_replace(
file_contents,
[
("training_steps = 5000", "training_steps = 1"),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=cfg.batch_size", "batch_size=1"),
],
)