diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht.py index 39bf3bba..ab979b38 100644 --- a/lerobot/common/envs/pusht.py +++ b/lerobot/common/envs/pusht.py @@ -1,4 +1,5 @@ import importlib +from collections import deque from typing import Optional import torch @@ -27,12 +28,16 @@ class PushtEnv(EnvBase): image_size=None, seed=1337, device="cpu", + num_prev_obs=1, + num_prev_action=0, ): super().__init__(device=device, batch_size=[]) self.frame_skip = frame_skip self.from_pixels = from_pixels self.pixels_only = pixels_only self.image_size = image_size + self.num_prev_obs = num_prev_obs + self.num_prev_action = num_prev_action if pixels_only: assert from_pixels @@ -56,6 +61,12 @@ class PushtEnv(EnvBase): self._make_spec() self._current_seed = self.set_seed(seed) + if self.num_prev_obs > 0: + self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) + self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) + if self.num_prev_action > 0: + self._prev_action_queue = deque(maxlen=self.num_prev_action) + def render(self, mode="rgb_array", width=384, height=384): if width != height: raise NotImplementedError() @@ -67,7 +78,8 @@ class PushtEnv(EnvBase): def _format_raw_obs(self, raw_obs): if self.from_pixels: - obs = {"image": torch.from_numpy(raw_obs["image"])} + image = torch.from_numpy(raw_obs["image"]) + obs = {"image": image} if not self.pixels_only: obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32) @@ -75,7 +87,6 @@ class PushtEnv(EnvBase): # TODO: obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)} - obs = TensorDict(obs, batch_size=[]) return obs def _reset(self, tensordict: Optional[TensorDict] = None): @@ -87,9 +98,21 @@ class PushtEnv(EnvBase): raw_obs = self._env.reset() assert self._current_seed == self._env._seed + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + # remove all previous observations + if "image" in obs: + self._prev_obs_image_queue.clear() + if "state" in obs: + self._prev_obs_state_queue.clear() + + # copy the current observation n times + obs = self._stack_prev_obs(obs) + td = TensorDict( { - "observation": self._format_raw_obs(raw_obs), + "observation": TensorDict(obs, batch_size=[]), "done": torch.tensor([False], dtype=torch.bool), }, batch_size=[], @@ -98,6 +121,40 @@ class PushtEnv(EnvBase): raise NotImplementedError() return td + def _stack_prev_obs(self, obs): + """When the queue is empty, copy the current observation n times.""" + assert self.num_prev_obs > 0 + + def stack_update_queue(prev_obs_queue, obs, num_prev_obs): + # get n most recent observations + prev_obs = list(prev_obs_queue)[-num_prev_obs:] + + # if not enough observations, copy the oldest observation until we obtain n observations + if len(prev_obs) == 0: + prev_obs = [obs] * num_prev_obs # queue is empty when env reset + elif len(prev_obs) < num_prev_obs: + prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs + + # stack n most recent observations with the current observation + stacked_obs = torch.stack(prev_obs + [obs], dim=0) + + # add current observation to the queue + # automatically remove oldest observation when queue is full + prev_obs_queue.appendleft(obs) + + return stacked_obs + + stacked_obs = {} + if "image" in obs: + stacked_obs["image"] = stack_update_queue( + self._prev_obs_image_queue, obs["image"], self.num_prev_obs + ) + if "state" in obs: + stacked_obs["state"] = stack_update_queue( + self._prev_obs_state_queue, obs["state"], self.num_prev_obs + ) + return stacked_obs + def _step(self, tensordict: TensorDict): td = tensordict # remove batch dim @@ -109,9 +166,14 @@ class PushtEnv(EnvBase): raw_obs, reward, done, info = self._env.step(action) sum_reward += reward + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + obs = self._stack_prev_obs(obs) + td = TensorDict( { - "observation": self._format_raw_obs(raw_obs), + "observation": TensorDict(obs, batch_size=[]), "reward": torch.tensor([sum_reward], dtype=torch.float32), # succes and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), @@ -124,14 +186,22 @@ class PushtEnv(EnvBase): def _make_spec(self): obs = {} if self.from_pixels: + image_shape = (3, self.image_size, self.image_size) + if self.num_prev_obs > 0: + image_shape = (self.num_prev_obs, *image_shape) + obs["image"] = BoundedTensorSpec( low=0, high=1, - shape=(3, self.image_size, self.image_size), + shape=image_shape, dtype=torch.float32, device=self.device, ) if not self.pixels_only: + state_shape = self._env.observation_space["agent_pos"].shape + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs, *state_shape) + obs["state"] = BoundedTensorSpec( low=0, high=512, @@ -141,6 +211,10 @@ class PushtEnv(EnvBase): ) else: # TODO(rcadene): add observation_space achieved_goal and desired_goal? + state_shape = self._env.observation_space["observation"].shape + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs, *state_shape) + obs["state"] = UnboundedContinuousTensorSpec( # TODO: shape=self._env.observation_space["observation"].shape, diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index a484c65a..aeec502e 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -1,7 +1,6 @@ import copy import time -import einops import hydra import torch import torch.nn as nn @@ -101,15 +100,13 @@ class DiffusionPolicy(nn.Module): # TODO(rcadene): remove unused step_count del step_count - # TODO(rcadene): remove unsqueeze hack... - if observation["image"].ndim == 3: - observation["image"] = observation["image"].unsqueeze(0) - observation["state"] = observation["state"].unsqueeze(0) + # TODO(rcadene): remove unsqueeze hack to add bsize=1 + observation["image"] = observation["image"].unsqueeze(0) + observation["state"] = observation["state"].unsqueeze(0) obs_dict = { - # TODO(rcadene): hack to add temporal dim - "image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"), - "agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"), + "image": observation["image"], + "agent_pos": observation["state"], } out = self.diffusion.predict_action(obs_dict) diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 6f18816a..f136fa55 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -13,7 +13,7 @@ shape_meta: shape: [2] horizon: 16 -n_obs_steps: 1 # TODO(rcadene): before 2 +n_obs_steps: 2 n_action_steps: 8 n_latency_steps: 0 dataset_obs_steps: ${n_obs_steps}