pass step as kwarg
This commit is contained in:
parent
2ccf89d78c
commit
9241b5e830
|
@ -130,7 +130,7 @@ def eval_policy(
|
|||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step)
|
||||
action = policy.select_action(observation, step=step)
|
||||
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, transform)
|
||||
|
|
|
@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = policy(batch, step)
|
||||
train_info = policy(batch, step=step)
|
||||
|
||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||
if step % cfg.log_freq == 0:
|
||||
|
|
Loading…
Reference in New Issue