pass step as kwarg
This commit is contained in:
parent
2ccf89d78c
commit
9241b5e830
|
@ -130,7 +130,7 @@ def eval_policy(
|
||||||
|
|
||||||
# get the next action for the environment
|
# get the next action for the environment
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
action = policy.select_action(observation, step)
|
action = policy.select_action(observation, step=step)
|
||||||
|
|
||||||
# apply inverse transform to unnormalize the action
|
# apply inverse transform to unnormalize the action
|
||||||
action = postprocess_action(action, transform)
|
action = postprocess_action(action, transform)
|
||||||
|
|
|
@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||||
|
|
||||||
train_info = policy(batch, step)
|
train_info = policy(batch, step=step)
|
||||||
|
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.log_freq == 0:
|
if step % cfg.log_freq == 0:
|
||||||
|
|
Loading…
Reference in New Issue