Add obs queue to pusht, Set n_obs_steps=2 for diffusion (Not fully tested)
This commit is contained in:
parent
cbbed590a9
commit
0f2fa4d9ef
|
@ -1,4 +1,5 @@
|
|||
import importlib
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
@ -27,12 +28,16 @@ class PushtEnv(EnvBase):
|
|||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
|
@ -56,6 +61,12 @@ class PushtEnv(EnvBase):
|
|||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
if width != height:
|
||||
raise NotImplementedError()
|
||||
|
@ -67,7 +78,8 @@ class PushtEnv(EnvBase):
|
|||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
obs = {"image": torch.from_numpy(raw_obs["image"])}
|
||||
image = torch.from_numpy(raw_obs["image"])
|
||||
obs = {"image": image}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
|
||||
|
@ -75,7 +87,6 @@ class PushtEnv(EnvBase):
|
|||
# TODO:
|
||||
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
obs = TensorDict(obs, batch_size=[])
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
|
@ -87,9 +98,21 @@ class PushtEnv(EnvBase):
|
|||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
# remove all previous observations
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.clear()
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.clear()
|
||||
|
||||
# copy the current observation n times
|
||||
obs = self._stack_prev_obs(obs)
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
|
@ -98,6 +121,40 @@ class PushtEnv(EnvBase):
|
|||
raise NotImplementedError()
|
||||
return td
|
||||
|
||||
def _stack_prev_obs(self, obs):
|
||||
"""When the queue is empty, copy the current observation n times."""
|
||||
assert self.num_prev_obs > 0
|
||||
|
||||
def stack_update_queue(prev_obs_queue, obs, num_prev_obs):
|
||||
# get n most recent observations
|
||||
prev_obs = list(prev_obs_queue)[-num_prev_obs:]
|
||||
|
||||
# if not enough observations, copy the oldest observation until we obtain n observations
|
||||
if len(prev_obs) == 0:
|
||||
prev_obs = [obs] * num_prev_obs # queue is empty when env reset
|
||||
elif len(prev_obs) < num_prev_obs:
|
||||
prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs
|
||||
|
||||
# stack n most recent observations with the current observation
|
||||
stacked_obs = torch.stack(prev_obs + [obs], dim=0)
|
||||
|
||||
# add current observation to the queue
|
||||
# automatically remove oldest observation when queue is full
|
||||
prev_obs_queue.appendleft(obs)
|
||||
|
||||
return stacked_obs
|
||||
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
stacked_obs["image"] = stack_update_queue(
|
||||
self._prev_obs_image_queue, obs["image"], self.num_prev_obs
|
||||
)
|
||||
if "state" in obs:
|
||||
stacked_obs["state"] = stack_update_queue(
|
||||
self._prev_obs_state_queue, obs["state"], self.num_prev_obs
|
||||
)
|
||||
return stacked_obs
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
# remove batch dim
|
||||
|
@ -109,9 +166,14 @@ class PushtEnv(EnvBase):
|
|||
raw_obs, reward, done, info = self._env.step(action)
|
||||
sum_reward += reward
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
obs = self._stack_prev_obs(obs)
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
|
@ -124,14 +186,22 @@ class PushtEnv(EnvBase):
|
|||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs, *image_shape)
|
||||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=1,
|
||||
shape=(3, self.image_size, self.image_size),
|
||||
shape=image_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
state_shape = self._env.observation_space["agent_pos"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs, *state_shape)
|
||||
|
||||
obs["state"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=512,
|
||||
|
@ -141,6 +211,10 @@ class PushtEnv(EnvBase):
|
|||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = self._env.observation_space["observation"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO:
|
||||
shape=self._env.observation_space["observation"].shape,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import copy
|
||||
import time
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -101,15 +100,13 @@ class DiffusionPolicy(nn.Module):
|
|||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
if observation["image"].ndim == 3:
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs_dict = {
|
||||
# TODO(rcadene): hack to add temporal dim
|
||||
"image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
|
||||
"agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ shape_meta:
|
|||
shape: [2]
|
||||
|
||||
horizon: 16
|
||||
n_obs_steps: 1 # TODO(rcadene): before 2
|
||||
n_obs_steps: 2
|
||||
n_action_steps: 8
|
||||
n_latency_steps: 0
|
||||
dataset_obs_steps: ${n_obs_steps}
|
||||
|
|
Loading…
Reference in New Issue