Fix test_examples
This commit is contained in:
parent
9c2f10bd04
commit
43a614c173
|
@ -53,7 +53,8 @@ step = 0
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
for k in batch:
|
||||||
|
batch[k] = batch[k].to(device, non_blocking=True)
|
||||||
info = policy(batch)
|
info = policy(batch)
|
||||||
if step % log_freq == 0:
|
if step % log_freq == 0:
|
||||||
num_samples = (step + 1) * cfg.batch_size
|
num_samples = (step + 1) * cfg.batch_size
|
||||||
|
|
|
@ -29,13 +29,14 @@ def test_examples_3_and_2():
|
||||||
with open(path, "r") as file:
|
with open(path, "r") as file:
|
||||||
file_contents = file.read()
|
file_contents = file.read()
|
||||||
|
|
||||||
# Do less steps, use CPU, and don't complicate things with dataloader workers.
|
# Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
|
||||||
file_contents = _find_and_replace(
|
file_contents = _find_and_replace(
|
||||||
file_contents,
|
file_contents,
|
||||||
[
|
[
|
||||||
("training_steps = 5000", "training_steps = 1"),
|
("training_steps = 5000", "training_steps = 1"),
|
||||||
("num_workers=4", "num_workers=0"),
|
("num_workers=4", "num_workers=0"),
|
||||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
||||||
|
("batch_size=cfg.batch_size", "batch_size=1"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue