fix argmax axis + fix seed

This commit is contained in:
Wael Karkoub 2024-06-10 16:51:06 +01:00
parent 1bcaad9bae
commit 18b8e95067
1 changed file with 5 additions and 4 deletions

View File

@@ -293,9 +293,7 @@ def eval_policy(
# this won't be included).
n_steps = rollout_data["done"].shape[1]
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
done_indices = torch.argmax(
rollout_data["done"].to(int), axis=1
) # (batch_size, rollout_steps) TODO: can't find any docs for the axis arg.
done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
@@ -307,7 +305,10 @@ def eval_policy(
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
all_successes.extend(batch_successes.tolist())
all_seeds.extend(seeds)
if seeds:
all_seeds.extend(seeds)
else:
all_seeds.append(None)
# FIXME: episode_data is either None or it doesn't exist
if return_episode_data: