Add obs queue to pusht, Set n_obs_steps=2 for diffusion (Not fully tested)

This commit is contained in:
Remi Cadene 2024-03-03 13:21:31 +00:00
parent cbbed590a9
commit 0f2fa4d9ef
3 changed files with 85 additions and 14 deletions

View File

@ -1,4 +1,5 @@
import importlib
from collections import deque
from typing import Optional
import torch
@ -27,12 +28,16 @@ class PushtEnv(EnvBase):
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
if pixels_only:
assert from_pixels
@ -56,6 +61,12 @@ class PushtEnv(EnvBase):
self._make_spec()
self._current_seed = self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
self._prev_action_queue = deque(maxlen=self.num_prev_action)
def render(self, mode="rgb_array", width=384, height=384):
if width != height:
raise NotImplementedError()
@ -67,7 +78,8 @@ class PushtEnv(EnvBase):
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
obs = {"image": torch.from_numpy(raw_obs["image"])}
image = torch.from_numpy(raw_obs["image"])
obs = {"image": image}
if not self.pixels_only:
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
@ -75,7 +87,6 @@ class PushtEnv(EnvBase):
# TODO:
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
obs = TensorDict(obs, batch_size=[])
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
@ -87,9 +98,21 @@ class PushtEnv(EnvBase):
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
# remove all previous observations
if "image" in obs:
self._prev_obs_image_queue.clear()
if "state" in obs:
self._prev_obs_state_queue.clear()
# copy the current observation n times
obs = self._stack_prev_obs(obs)
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
@ -98,6 +121,40 @@ class PushtEnv(EnvBase):
raise NotImplementedError()
return td
def _stack_prev_obs(self, obs):
"""When the queue is empty, copy the current observation n times."""
assert self.num_prev_obs > 0
def stack_update_queue(prev_obs_queue, obs, num_prev_obs):
# get n most recent observations
prev_obs = list(prev_obs_queue)[-num_prev_obs:]
# if not enough observations, copy the oldest observation until we obtain n observations
if len(prev_obs) == 0:
prev_obs = [obs] * num_prev_obs # queue is empty when env reset
elif len(prev_obs) < num_prev_obs:
prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs
# stack n most recent observations with the current observation
stacked_obs = torch.stack(prev_obs + [obs], dim=0)
# add current observation to the queue
# automatically remove oldest observation when queue is full
prev_obs_queue.appendleft(obs)
return stacked_obs
stacked_obs = {}
if "image" in obs:
stacked_obs["image"] = stack_update_queue(
self._prev_obs_image_queue, obs["image"], self.num_prev_obs
)
if "state" in obs:
stacked_obs["state"] = stack_update_queue(
self._prev_obs_state_queue, obs["state"], self.num_prev_obs
)
return stacked_obs
def _step(self, tensordict: TensorDict):
td = tensordict
# remove batch dim
@ -109,9 +166,14 @@ class PushtEnv(EnvBase):
raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
obs = self._stack_prev_obs(obs)
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
@ -124,14 +186,22 @@ class PushtEnv(EnvBase):
def _make_spec(self):
obs = {}
if self.from_pixels:
image_shape = (3, self.image_size, self.image_size)
if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs, *image_shape)
obs["image"] = BoundedTensorSpec(
low=0,
high=1,
shape=(3, self.image_size, self.image_size),
shape=image_shape,
dtype=torch.float32,
device=self.device,
)
if not self.pixels_only:
state_shape = self._env.observation_space["agent_pos"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape)
obs["state"] = BoundedTensorSpec(
low=0,
high=512,
@ -141,6 +211,10 @@ class PushtEnv(EnvBase):
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = self._env.observation_space["observation"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
# TODO:
shape=self._env.observation_space["observation"].shape,

View File

@ -1,7 +1,6 @@
import copy
import time
import einops
import hydra
import torch
import torch.nn as nn
@ -101,15 +100,13 @@ class DiffusionPolicy(nn.Module):
# TODO(rcadene): remove unused step_count
del step_count
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs_dict = {
# TODO(rcadene): hack to add temporal dim
"image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
"agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)

View File

@ -13,7 +13,7 @@ shape_meta:
shape: [2]
horizon: 16
n_obs_steps: 1 # TODO(rcadene): before 2
n_obs_steps: 2
n_action_steps: 8
n_latency_steps: 0
dataset_obs_steps: ${n_obs_steps}