From 43a614c17371bbc39f06111c45858c0964c63358 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 16 Apr 2024 14:07:16 +0100 Subject: [PATCH] Fix test_examples --- examples/3_train_policy.py | 3 ++- tests/test_examples.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 83563ffd..d2e8b8c9 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -53,7 +53,8 @@ step = 0 done = False while not done: for batch in dataloader: - batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} + for k in batch: + batch[k] = batch[k].to(device, non_blocking=True) info = policy(batch) if step % log_freq == 0: num_samples = (step + 1) * cfg.batch_size diff --git a/tests/test_examples.py b/tests/test_examples.py index c264610b..6cab7a1a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -29,13 +29,14 @@ def test_examples_3_and_2(): with open(path, "r") as file: file_contents = file.read() - # Do less steps, use CPU, and don't complicate things with dataloader workers. + # Do fewer steps, use a smaller batch, use CPU, and don't complicate things with dataloader workers. file_contents = _find_and_replace( file_contents, [ ("training_steps = 5000", "training_steps = 1"), ("num_workers=4", "num_workers=0"), ('device = torch.device("cuda")', 'device = torch.device("cpu")'), + ("batch_size=cfg.batch_size", "batch_size=1"), ], )