PushtEnv inherits from AbstractEnv; improve factory normalization

This commit is contained in:
Cadene 2024-03-11 14:05:23 +00:00
parent ebd5c786f1
commit bdd2c801bc
3 changed files with 19 additions and 29 deletions

View File

@ -106,7 +106,9 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
transform = NormalizeTransform(stats, in_keys, mode="min_max") # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transform = NormalizeTransform(stats, in_keys, mode=normalization_mode)
offline_buffer.set_transform(transform) offline_buffer.set_transform(transform)
if not overwrite_sampler: if not overwrite_sampler:

View File

@ -11,39 +11,38 @@ from torchrl.data.tensor_specs import (
DiscreteTensorSpec, DiscreteTensorSpec,
UnboundedContinuousTensorSpec, UnboundedContinuousTensorSpec,
) )
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None _has_gym = importlib.util.find_spec("gym") is not None
class PushtEnv(EnvBase): class PushtEnv(AbstractEnv):
def __init__( def __init__(
self, self,
task="pusht",
frame_skip: int = 1, frame_skip: int = 1,
from_pixels: bool = False, from_pixels: bool = False,
pixels_only: bool = False, pixels_only: bool = False,
image_size=None, image_size=None,
seed=1337, seed=1337,
device="cpu", device="cpu",
num_prev_obs=0, num_prev_obs=1,
num_prev_action=0, num_prev_action=0,
): ):
super().__init__(device=device, batch_size=[]) super().__init__(
self.frame_skip = frame_skip task=task,
self.from_pixels = from_pixels frame_skip=frame_skip,
self.pixels_only = pixels_only from_pixels=from_pixels,
self.image_size = image_size pixels_only=pixels_only,
self.num_prev_obs = num_prev_obs image_size=image_size,
self.num_prev_action = num_prev_action seed=seed,
device=device,
if pixels_only: num_prev_obs=num_prev_obs,
assert from_pixels num_prev_action=num_prev_action,
if from_pixels: )
assert image_size
if not _has_gym: if not _has_gym:
raise ImportError("Cannot import gym.") raise ImportError("Cannot import gym.")
@ -56,16 +55,6 @@ class PushtEnv(EnvBase):
self._env = PushTImageEnv(render_size=self.image_size) self._env = PushTImageEnv(render_size=self.image_size)
self._make_spec()
self._current_seed = self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def render(self, mode="rgb_array", width=384, height=384): def render(self, mode="rgb_array", width=384, height=384):
if width != height: if width != height:
raise NotImplementedError() raise NotImplementedError()

View File

@ -49,7 +49,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.model, self.optimizer = build_act_model_and_optimizer(cfg) self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}") logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device) self.to(self.device)
def update(self, replay_buffer, step): def update(self, replay_buffer, step):
@ -156,7 +155,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# TODO(rcadene): remove unsqueeze hack to add bsize=1 # TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0) observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0) # observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove hack # TODO(rcadene): remove hack
# add 1 camera dimension # add 1 camera dimension