diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py
index ab979b38..927a1ba7 100644
--- a/lerobot/common/envs/pusht.py
+++ b/lerobot/common/envs/pusht.py
@@ -157,13 +157,20 @@ class PushtEnv(EnvBase):
 
     def _step(self, tensordict: TensorDict):
         td = tensordict
-        # remove batch dim
-        action = td["action"].squeeze(0).numpy()
+        action = td["action"].numpy()
         # step expects shape=(4,) so we pad if necessary
         # TODO(rcadene): add info["is_success"] and info["success"] ?
         sum_reward = 0
-        for _ in range(self.frame_skip):
-            raw_obs, reward, done, info = self._env.step(action)
+
+        if action.ndim == 1:
+            action = action[None].repeat(self.frame_skip, axis=0)
+        else:
+            if self.frame_skip > 1:
+                raise NotImplementedError()
+
+        num_action_steps = action.shape[0]
+        for i in range(num_action_steps):
+            raw_obs, reward, done, info = self._env.step(action[i])
             sum_reward += reward
 
         obs = self._format_raw_obs(raw_obs)
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index aeec502e..df05bfd8 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -12,8 +12,6 @@ from diffusion_policy.model.vision.model_getter import get_resnet
 from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from .multi_image_obs_encoder import MultiImageObsEncoder
 
-FIRST_ACTION = 0
-
 
 class DiffusionPolicy(nn.Module):
     def __init__(
@@ -110,8 +108,7 @@ class DiffusionPolicy(nn.Module):
         }
         out = self.diffusion.predict_action(obs_dict)
 
-        # TODO(rcadene): add possibility to return >1 timestemps
-        action = out["action"].squeeze(0)[FIRST_ACTION]
+        action = out["action"].squeeze(0)
 
         return action
 
     def update(self, replay_buffer, step):
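
Note on the PushtEnv hunk: _step now accepts either a single action of shape (action_dim,), which it tiles across frame_skip inner steps, or a precomputed sequence of shape (num_action_steps, action_dim), which it plays back one row at a time (sequences are only supported when frame_skip == 1, hence the NotImplementedError). A minimal sketch of that contract, assuming an illustrative 4-dim action and frame_skip = 3 (neither value is taken from the diff):

    import numpy as np

    frame_skip = 3
    single = np.zeros(4)                    # one action, shape (action_dim,)
    tiled = single[None].repeat(frame_skip, axis=0)
    assert tiled.shape == (3, 4)            # one identical row per inner env step

    sequence = np.zeros((8, 4))             # a policy-provided action sequence
    for i in range(sequence.shape[0]):
        row = sequence[i]                   # each row would be passed to self._env.step(...)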
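
Note on the DiffusionPolicy hunks: with FIRST_ACTION removed, the output of predict_action is no longer indexed down to a single step; after squeeze(0) the policy returns the full predicted sequence of shape (horizon, action_dim), which the reworked PushtEnv._step can play back when frame_skip == 1. A hedged sketch of the shape change, where horizon = 8 and action_dim = 2 are placeholders for illustration:

    import torch

    out = {"action": torch.zeros(1, 8, 2)}  # (batch=1, horizon, action_dim)
    action = out["action"].squeeze(0)       # (8, 2): the whole sequence
    assert action.shape == (8, 2)           # previously: out["action"].squeeze(0)[0]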