Remove redundant slicing operation in Diffusion Policy (#240)
This commit is contained in:
parent
042e193995
commit
cf15cba5fc
|
@ -239,10 +239,8 @@ class DiffusionModel(nn.Module):
|
||||||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||||
|
|
||||||
# run sampling
|
# run sampling
|
||||||
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||||
|
|
||||||
# `horizon` steps worth of actions (from the first observation).
|
|
||||||
actions = sample[..., : self.config.output_shapes["action"][0]]
|
|
||||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||||
start = n_obs_steps - 1
|
start = n_obs_steps - 1
|
||||||
end = start + self.config.n_action_steps
|
end = start + self.config.n_action_steps
|
||||||
|
|
Loading…
Reference in New Issue