diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
index 4956530a..9c652c0a 100644
--- a/lerobot/common/policies/abstract.py
+++ b/lerobot/common/policies/abstract.py
@@ -1,15 +1,20 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from collections import deque
 
 import torch
 from torch import Tensor, nn
 
 
-class AbstractPolicy(nn.Module):
+class AbstractPolicy(nn.Module, ABC):
+    """Base policy from which all policies should be derived.
+
+    The forward method should generally not be overridden, as it plays the role of handling multi-step policies. See
+    its documentation for more information.
+    """
+
     @abstractmethod
     def update(self, replay_buffer, step):
         """One step of the policy's learning algorithm."""
-        pass
 
     def save(self, fp):
         torch.save(self.state_dict(), fp)
@@ -24,7 +29,6 @@ class AbstractPolicy(nn.Module):
 
         Should return a (batch_size, n_action_steps, *) tensor of actions.
         """
-        pass
 
     def forward(self, *args, **kwargs):
         """Inference step that makes multi-step policies compatible with their single-step environments.
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index e87f155e..e0499cdb 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -153,10 +153,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
 
         self.eval()
 
-        # TODO(rcadene): remove unsqueeze hack to add bsize=1
-        observation["image", "top"] = observation["image", "top"].unsqueeze(0)
-        # observation["state"] = observation["state"].unsqueeze(0)
-
         # TODO(rcadene): remove hack
         # add 1 camera dimension
         observation["image", "top"] = observation["image", "top"].unsqueeze(1)
@@ -180,11 +176,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
            # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
            # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
 
-        # remove bsize=1
-        action = action.squeeze(0)
 
-        # take first predicted action or n first actions
-        action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
+        action = action[: self.n_action_steps]
         return action
 
     def _forward(self, qpos, image, actions=None, is_pad=None):
diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
index 6a1d3c0d..94dc6f49 100644
--- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
+++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
@@ -1,37 +1,15 @@
 import copy
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Tuple, Union
 
-import timm
 import torch
 import torch.nn as nn
 import torchvision
-from robomimic.models.base_nets import SpatialSoftmax
 
 from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
 from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
 from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
 
 
-class RgbEncoder(nn.Module):
-    """Following `VisualCore` from Robomimic 0.2.0."""
-
-    def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
-        """
-        resnet_name: a timm model name.
-        pretrained: whether to use timm pretrained weights.
-        num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
- """ - super().__init__() - self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="") - # Figure out the feature map shape. - with torch.inference_mode(): - feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) - - def forward(self, x): - return torch.flatten(self.pool(self.backbone(x)), start_dim=1) - - class MultiImageObsEncoder(ModuleAttrMixin): def __init__( self, @@ -46,7 +24,7 @@ class MultiImageObsEncoder(ModuleAttrMixin): share_rgb_model: bool = False, # renormalize rgb input with imagenet normalization # assuming input in [0,1] - norm_mean_std: Optional[tuple[float, float]] = None, + imagenet_norm: bool = False, ): """ Assumes rgb input: B,C,H,W @@ -120,9 +98,10 @@ class MultiImageObsEncoder(ModuleAttrMixin): this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) # configure normalizer this_normalizer = nn.Identity() - if norm_mean_std is not None: + if imagenet_norm: + # TODO(rcadene): move normalizer to dataset and env this_normalizer = torchvision.transforms.Normalize( - mean=norm_mean_std[0], std=norm_mean_std[1] + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index e779596c..db004a71 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -7,7 +7,7 @@ import torch from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder +from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder class DiffusionPolicy(AbstractPolicy): @@ -38,7 +38,7 @@ class DiffusionPolicy(AbstractPolicy): self.cfg = cfg noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) - rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model) + rgb_model = hydra.utils.instantiate(cfg_rgb_model) obs_encoder = MultiImageObsEncoder( rgb_model=rgb_model, **cfg_obs_encoder, diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index 48955459..4c104bcd 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -128,11 +128,6 @@ class TDMPC(AbstractPolicy): def select_action(self, observation, step_count): t0 = step_count.item() == 0 - # TODO(rcadene): remove unsqueeze hack... - if observation["image"].ndim == 3: - observation["image"] = observation["image"].unsqueeze(0) - observation["state"] = observation["state"].unsqueeze(0) - obs = { # TODO(rcadene): remove contiguous hack... 
"rgb": observation["image"].contiguous(), @@ -149,7 +144,7 @@ class TDMPC(AbstractPolicy): if self.cfg.mpc: a = self.plan(z, t0=t0, step=step) else: - a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0) + a = self.model.pi(z, self.cfg.min_std * self.model.training) return a @torch.no_grad() diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 7de44102..0dae5056 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -42,8 +42,8 @@ policy: num_inference_steps: 100 obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null - diffusion_step_embed_dim: 128 - down_dims: [512, 1024, 2048] + diffusion_step_embed_dim: 256 # before 128 + down_dims: [256, 512, 1024] # before [512, 1024, 2048] kernel_size: 5 n_groups: 8 cond_predict_scale: True @@ -81,12 +81,12 @@ obs_encoder: # random_crop: True use_group_norm: True share_rgb_model: False - norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) + imagenet_norm: True rgb_model: - model_name: resnet18 - pretrained: false - num_keypoints: 32 + _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet + name: resnet18 + weights: null ema: _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel @@ -109,13 +109,13 @@ training: debug: False resume: True # optimization - lr_scheduler: cosine - lr_warmup_steps: 500 - num_epochs: 500 + # lr_scheduler: cosine + # lr_warmup_steps: 500 + num_epochs: 8000 # gradient_accumulate_every: 1 # EMA destroys performance when used with BatchNorm # replace BatchNorm with GroupNorm. - use_ema: True + # use_ema: True freeze_encoder: False # training loop control # in epochs diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 839c12bb..7cfb796a 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -135,8 +135,8 @@ def eval(cfg: dict, out_dir=None): cfg.rollout_batch_size, create_env_fn=make_env, create_env_kwargs=[ - {"cfg": cfg, "seed": s, "transform": offline_buffer.transform} - for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + {"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform} + for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) ], ) diff --git a/tests/test_policies.py b/tests/test_policies.py index 7d9a4dce..92324485 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -84,6 +84,9 @@ def test_abstract_policy_forward(): self.n_action_steps = n_action_steps self.n_policy_invocations = 0 + def update(self): + pass + def select_action(self): self.n_policy_invocations += 1 return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)