diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index ed7854ff..c22ae698 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -59,96 +59,10 @@ def make_dataset(
         transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
     )
     stats = compute_or_load_stats(stats_dataset)
+    # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
     normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
-    # # TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
-    # # (Pdb) stats['observation']['state']['mean']
-    # # tensor([-0.0071, -0.6293,  1.0351, -0.0517, -0.4642, -0.0754,  0.4751, -0.0373,
-    # #         -0.3324,  0.9034, -0.2258, -0.3127, -0.2412,  0.6866])
-    # stats["observation", "state", "mean"] = torch.tensor(
-    #     [
-    #         -0.00740268,
-    #         -0.63187766,
-    #         1.0356655,
-    #         -0.05027218,
-    #         -0.46199223,
-    #         -0.07467502,
-    #         0.47467607,
-    #         -0.03615446,
-    #         -0.33203387,
-    #         0.9038929,
-    #         -0.22060776,
-    #         -0.31011587,
-    #         -0.23484458,
-    #         0.6842416,
-    #     ]
-    # )
-    # # (Pdb) stats['observation']['state']['std']
-    # # tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
-    # #         0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
-    # stats["observation", "state", "std"] = torch.tensor(
-    #     [
-    #         0.01219023,
-    #         0.2975381,
-    #         0.16728032,
-    #         0.04733803,
-    #         0.1486037,
-    #         0.08788499,
-    #         0.31752336,
-    #         0.1049916,
-    #         0.27933604,
-    #         0.18094037,
-    #         0.26604933,
-    #         0.30466506,
-    #         0.5298686,
-    #         0.25505227,
-    #     ]
-    # )
-    # # (Pdb) stats['action']['mean']
-    # # tensor([-0.0075, -0.6346,  1.0353, -0.0465, -0.4686, -0.0738,  0.3723, -0.0396,
-    # #         -0.3184,  0.8991, -0.2065, -0.3182, -0.2338,  0.5593])
-    # stats["action"]["mean"] = torch.tensor(
-    #     [
-    #         -0.00756444,
-    #         -0.6281845,
-    #         1.0312834,
-    #         -0.04664314,
-    #         -0.47211358,
-    #         -0.074527,
-    #         0.37389806,
-    #         -0.03718753,
-    #         -0.3261143,
-    #         0.8997205,
-    #         -0.21371077,
-    #         -0.31840396,
-    #         -0.23360962,
-    #         0.551947,
-    #     ]
-    # )
-    # # (Pdb) stats['action']['std']
-    # # tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
-    # #         0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
-    # stats["action"]["std"] = torch.tensor(
-    #     [
-    #         0.01252818,
-    #         0.2957442,
-    #         0.16701928,
-    #         0.04584508,
-    #         0.14833844,
-    #         0.08763024,
-    #         0.30665937,
-    #         0.10600077,
-    #         0.27572668,
-    #         0.1805853,
-    #         0.26304692,
-    #         0.30708534,
-    #         0.5305411,
-    #         0.38381037,
-    #     ]
-    # )
-    # transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))  # noqa: F821
-
 
     transforms = v2.Compose(
         [
            # TODO(rcadene): we need to do something about image_keys
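
Note on the factory.py hunk above: the hard-coded ALOHA stats are dropped, and `normalization_mode` still selects "mean_std" for aloha and "min_max" otherwise. For reference, here is a minimal sketch of what the two modes typically compute per key; the flat `stats` dict layout, the epsilon, and the [-1, 1] target range for "min_max" are illustrative assumptions, not taken from this diff.

import torch


def normalize(x: torch.Tensor, stats: dict[str, torch.Tensor], mode: str) -> torch.Tensor:
    """Illustrative per-key normalization with precomputed dataset statistics."""
    if mode == "mean_std":
        # Zero-mean, unit-variance scaling.
        return (x - stats["mean"]) / (stats["std"] + 1e-8)
    if mode == "min_max":
        # Rescale the observed range to [-1, 1].
        return (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8) * 2 - 1
    raise ValueError(f"Unknown normalization mode: {mode}")
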
diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
index beebd8ac..6dc72bef 100644
--- a/lerobot/common/policies/abstract.py
+++ b/lerobot/common/policies/abstract.py
@@ -4,3 +4,79 @@
 import torch
 from torch import Tensor, nn
 
+
+class AbstractPolicy(nn.Module):
+    """Base policy which all policies should be derived from.
+
+    The forward method should generally not be overridden as it plays the role of handling multi-step policies. See
+    its documentation for more information.
+
+    Note:
+        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
+            1. set the required class attributes:
+                - for classes inheriting from `AbstractDataset`: `available_datasets`
+                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
+                - for classes inheriting from `AbstractPolicy`: `name`
+            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
+            3. update variables in `tests/test_available.py` by importing your new class
+    """
+
+    name: str | None = None  # same name should be used to instantiate the policy in factory.py
+
+    def __init__(self, n_action_steps: int | None):
+        """
+        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
+            action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method
+            then adds that dimension.
+        """
+        super().__init__()
+        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
+        self.n_action_steps = n_action_steps
+        self.clear_action_queue()
+
+    def update(self, replay_buffer, step):
+        """One step of the policy's learning algorithm."""
+        raise NotImplementedError("Abstract method")
+
+    def save(self, fp):
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        d = torch.load(fp)
+        self.load_state_dict(d)
+
+    def select_actions(self, observation) -> Tensor:
+        """Select an action (or trajectory of actions) based on an observation during rollout.
+
+        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor
+        of actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
+        """
+        raise NotImplementedError("Abstract method")
+
+    def clear_action_queue(self):
+        """This should be called whenever the environment is reset."""
+        if self.n_action_steps is not None:
+            self._action_queue = deque([], maxlen=self.n_action_steps)
+
+    def forward(self, *args, **kwargs) -> Tensor:
+        """Inference step that makes multi-step policies compatible with their single-step environments.
+
+        WARNING: In general, this should not be overridden.
+
+        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
+        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
+        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the
+        environment observation, and (3) repopulates the action queue when empty. This method handles the
+        aforementioned logic so that the subclass doesn't have to.
+
+        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
+        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H
+           is the action trajectory horizon and * is the action dimensions.
+        2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
+        """
+        if self.n_action_steps is None:
+            return self.select_actions(*args, **kwargs)
+        if len(self._action_queue) == 0:
+            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
+            # (n_action_steps, batch_size, *), hence the transpose.
+            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
+        return self._action_queue.popleft()
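
To make the action-queue contract of `AbstractPolicy.forward` above concrete, here is a minimal sketch of a subclass. The `RepeatZeroActionPolicy` class, its dimensions, and the zero actions are hypothetical; only the (batch_size, n_action_steps, *) output shape of `select_actions` comes from the docstring.

import torch
from torch import Tensor

from lerobot.common.policies.abstract import AbstractPolicy


class RepeatZeroActionPolicy(AbstractPolicy):
    """Hypothetical policy that plans a chunk of zero actions from the latest observation."""

    name = "repeat_zero"

    def __init__(self, action_dim: int = 14, n_action_steps: int = 100):
        super().__init__(n_action_steps)
        self.action_dim = action_dim

    def select_actions(self, observation: Tensor) -> Tensor:
        # Must return (batch_size, n_action_steps, action_dim), as assumed by `forward`.
        batch_size = observation.shape[0]
        return torch.zeros(batch_size, self.n_action_steps, self.action_dim)


policy = RepeatZeroActionPolicy()
obs = torch.zeros(2, 14)
first_action = policy(obs)   # queue was empty: `select_actions` fills it, the first (2, 14) action is popped
second_action = policy(obs)  # served from the queue; `select_actions` is not called again
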
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index 75d5ca0e..834dd9b2 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -67,7 +67,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
 
     def __init__(self, cfg, device, n_action_steps=1):
         """
-        TODO(alexander-soare): Add documentation for all parameters.
+        TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
         """
         super().__init__()
         if getattr(cfg, "n_obs_steps", 1) != 1:
@@ -109,6 +109,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
         )
 
         # Backbone for image feature extraction.
+        self.image_normalizer = transforms.Normalize(
+            mean=cfg.image_normalization.mean, std=cfg.image_normalization.std
+        )
         backbone_model = getattr(torchvision.models, cfg.backbone)(
             replace_stride_with_dilation=[False, False, cfg.dilation],
             pretrained=cfg.pretrained_backbone,
@@ -275,9 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         return info
 
     def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
-        # TODO(now): Maybe this shouldn't be here?
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        images = normalize(batch["observation.images.top"])
+        images = self.image_normalizer(batch["observation.images.top"])
 
         if return_loss:  # training time
             actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index 93e5ba5d..9785358b 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -151,7 +151,6 @@ class DiffusionPolicy(nn.Module):
 
         self.diffusion.train()
 
-        data_s = time.time() - start_time
 
         loss = self.diffusion.compute_loss(batch)
         loss.backward()
@@ -172,7 +171,6 @@ class DiffusionPolicy(nn.Module):
             "loss": loss.item(),
             "grad_norm": float(grad_norm),
             "lr": self.lr_scheduler.get_last_lr()[0],
-            "data_s": data_s,
             "update_s": time.time() - start_time,
         }
 
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index 80f50003..cd34d115 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -1,6 +1,6 @@
 # @package _global_
 
-offline_steps: 2000
+offline_steps: 80000
 online_steps: 0
 
 eval_episodes: 1
@@ -54,8 +54,12 @@ policy:
 
   temporal_agg: false
 
-  state_dim: ???
-  action_dim: ???
+  state_dim: 14
+  action_dim: 14
+
+  image_normalization:
+    mean: [0.485, 0.456, 0.406]
+    std: [0.229, 0.224, 0.225]
 
 delta_timestamps:
   observation.images.top: [0.0]
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index b43f4ed1..72966211 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -86,7 +86,9 @@ def eval_policy(
     def maybe_render_frame(env):
         if save_video:  # noqa: B023
             if return_first_video:
-                visu = env.envs[0].render(mode="visualization")
+                # TODO(now): Put mode back in.
+                visu = env.envs[0].render()
+                # visu = env.envs[0].render(mode="visualization")
                 visu = visu[None, ...]  # add batch dim
             else:
                 # TODO(now): Put mode back in.
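
The `image_normalization` block added to act.yaml above is what feeds the new `self.image_normalizer` in the ACT policy, replacing the constants that were hard-coded in its `forward`. Below is a minimal sketch of that wiring with plain lists standing in for the config node; the example image shape is arbitrary, and the [0, 1] input scaling follows from the 1/255 `Prod` transform in the dataset factory.

import torch
from torchvision import transforms

# ImageNet statistics, matching the new `policy.image_normalization` entry in act.yaml.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

image_normalizer = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Expects float images in [0, 1] with shape (..., 3, H, W), e.g. batch["observation.images.top"].
images = torch.rand(2, 3, 480, 640)
normalized = image_normalizer(images)  # per channel: (x - mean) / std
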
diff --git a/scripts/convert_act_weights.py b/scripts/convert_act_weights.py
deleted file mode 100644
index d5e38796..00000000
--- a/scripts/convert_act_weights.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-
-from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import init_hydra_config
-
-cfg = init_hydra_config(
-    "/home/alexander/Projects/lerobot/outputs/train/act_aloha_sim_transfer_cube_human/.hydra/config.yaml"
-)
-
-policy = make_policy(cfg)
-
-state_dict = torch.load("/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt")
-
-# Remove keys based on what they start with.
-
-start_removals = [
-    # There is a bug that means the pretrained model doesn't even use the final decoder layers.
-    *[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
-    "model.is_pad_head.",
-]
-
-for to_remove in start_removals:
-    for k in list(state_dict.keys()):
-        if k.startswith(to_remove):
-            del state_dict[k]
-
-
-# Replace keys based on what they start with.
-
-start_replacements = [
-    ("model.", ""),
-    ("query_embed.weight", "pos_embed.weight"),
-    ("pos_table", "vae_encoder_pos_enc"),
-    ("pos_embed.weight", "decoder_pos_embed.weight"),
-    ("encoder.", "vae_encoder."),
-    ("encoder_action_proj.", "vae_encoder_action_input_proj."),
-    ("encoder_joint_proj.", "vae_encoder_robot_state_input_proj."),
-    ("latent_proj.", "vae_encoder_latent_output_proj."),
-    ("latent_proj.", "vae_encoder_latent_output_proj."),
-    ("input_proj.", "encoder_img_feat_input_proj."),
-    ("input_proj_robot_state", "encoder_robot_state_input_proj"),
-    ("latent_out_proj.", "encoder_latent_input_proj."),
-    ("transformer.encoder.", "encoder."),
-    ("transformer.decoder.", "decoder."),
-    ("backbones.0.0.body.", "backbone."),
-    ("additional_pos_embed.weight", "encoder_robot_and_latent_pos_embed.weight"),
-    ("cls_embed.weight", "vae_encoder_cls_embed.weight"),
-]
-
-for to_replace, replace_with in start_replacements:
-    for k in list(state_dict.keys()):
-        if k.startswith(to_replace):
-            k_ = replace_with + k.removeprefix(to_replace)
-            state_dict[k_] = state_dict[k]
-            del state_dict[k]
-
-
-missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
-
-if len(missing_keys) != 0:
-    print("MISSING KEYS")
-    print(missing_keys)
-if len(unexpected_keys) != 0:
-    print("UNEXPECTED KEYS")
-    print(unexpected_keys)
-
-# if len(missing_keys) != 0 or len(unexpected_keys) != 0:
-#     print("Failed due to mismatch in state dicts.")
-#     exit()
-
-policy.save("/tmp/weights.pth")