Remove redundant slicing operation in Diffusion Policy (#240)

This commit is contained in:
Alexander Soare 2024-06-03 13:04:24 +01:00 committed by GitHub
parent 042e193995
commit cf15cba5fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 3 deletions

View File

@ -239,10 +239,8 @@ class DiffusionModel(nn.Module):
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# run sampling # run sampling
sample = self.conditional_sample(batch_size, global_cond=global_cond) actions = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.config.output_shapes["action"][0]]
# Extract `n_action_steps` steps worth of actions (from the current observation). # Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1 start = n_obs_steps - 1
end = start + self.config.n_action_steps end = start + self.config.n_action_steps