diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 29f40bc6..63bde225 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -106,7 +106,9 @@ def make_offline_buffer( stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) - transform = NormalizeTransform(stats, in_keys, mode="min_max") + # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std + normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" + transform = NormalizeTransform(stats, in_keys, mode=normalization_mode) offline_buffer.set_transform(transform) if not overwrite_sampler: diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index ff49f791..5c1a19f4 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -11,39 +11,38 @@ from torchrl.data.tensor_specs import ( DiscreteTensorSpec, UnboundedContinuousTensorSpec, ) -from torchrl.envs import EnvBase from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform +from lerobot.common.envs.abstract import AbstractEnv from lerobot.common.utils import set_seed _has_gym = importlib.util.find_spec("gym") is not None -class PushtEnv(EnvBase): +class PushtEnv(AbstractEnv): def __init__( self, + task="pusht", frame_skip: int = 1, from_pixels: bool = False, pixels_only: bool = False, image_size=None, seed=1337, device="cpu", - num_prev_obs=0, + num_prev_obs=1, num_prev_action=0, ): - super().__init__(device=device, batch_size=[]) - self.frame_skip = frame_skip - self.from_pixels = from_pixels - self.pixels_only = pixels_only - self.image_size = image_size - self.num_prev_obs = num_prev_obs - self.num_prev_action = num_prev_action - - if pixels_only: - assert from_pixels - if from_pixels: - assert image_size - + super().__init__( + task=task, + frame_skip=frame_skip, + from_pixels=from_pixels, + pixels_only=pixels_only, + image_size=image_size, + seed=seed, + device=device, + num_prev_obs=num_prev_obs, + num_prev_action=num_prev_action, + ) if not _has_gym: raise ImportError("Cannot import gym.") @@ -56,16 +55,6 @@ class PushtEnv(EnvBase): self._env = PushTImageEnv(render_size=self.image_size) - self._make_spec() - self._current_seed = self.set_seed(seed) - - if self.num_prev_obs > 0: - self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) - self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs) - if self.num_prev_action > 0: - raise NotImplementedError() - # self._prev_action_queue = deque(maxlen=self.num_prev_action) - def render(self, mode="rgb_array", width=384, height=384): if width != height: raise NotImplementedError() diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 7928b3ab..0270ee6a 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -49,7 +49,6 @@ class ActionChunkingTransformerPolicy(nn.Module): self.model, self.optimizer = build_act_model_and_optimizer(cfg) self.kl_weight = self.cfg.kl_weight logging.info(f"KL Weight {self.kl_weight}") - self.to(self.device) def update(self, replay_buffer, step): @@ -156,7 +155,7 @@ class ActionChunkingTransformerPolicy(nn.Module): # TODO(rcadene): remove unsqueeze hack to add bsize=1 observation["image"] = observation["image"].unsqueeze(0) - observation["state"] = observation["state"].unsqueeze(0) + # observation["state"] = observation["state"].unsqueeze(0) # TODO(rcadene): remove hack # add 1 camera dimension