fix, it's training now!

This commit is contained in:
Cadene 2024-03-24 23:10:16 +00:00 committed by Simon Alibert
parent 127de1258d
commit be6364f109
1 changed file with 16 additions and 1 deletion

View File

@@ -224,7 +224,22 @@ def train(cfg: dict, out_dir=None, job_name=None):
policy=td_policy,
auto_cast_to_device=True,
)
assert len(rollout) <= cfg.env.episode_length
assert (
len(rollout.batch_size) == 2
), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
num_parallel_env = rollout.batch_size[0]
if num_parallel_env != 1:
# TODO(rcadene): when num_parallel_env > 1, episode needs to be incremented and we need to add tests
raise NotImplementedError()
num_max_steps = rollout.batch_size[1]
assert num_max_steps <= cfg.env.episode_length
# reshape to have a list of steps to insert into online_buffer
rollout = rollout.reshape(num_parallel_env * num_max_steps)
# set same episode index for all time steps contained in this rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)