Merge branch 'main' into user/alexander-soare/multistep_policy_and_serial_env

This commit is contained in:
Alexander Soare 2024-03-18 18:27:50 +00:00
commit 09ddd9bf92
2 changed files with 2 additions and 0 deletions

View File

@ -33,6 +33,7 @@ def eval_policy(
fps: int = 15, fps: int = 15,
return_first_video: bool = False, return_first_video: bool = False,
): ):
policy.eval()
start = time.time() start = time.time()
sum_rewards = [] sum_rewards = []
max_rewards = [] max_rewards = []

View File

@ -190,6 +190,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
if offline_step == 0: if offline_step == 0:
logging.info("Start offline training on a fixed dataset") logging.info("Start offline training on a fixed dataset")
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
policy.train()
train_info = policy.update(offline_buffer, step) train_info = policy.update(offline_buffer, step)
if step % cfg.log_freq == 0: if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline) log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)