From 49a3db9f2fdd95d468a4637317acbdaa58d3da53 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Thu, 30 May 2024 12:06:57 +0000
Subject: [PATCH] Add aloha2_real, Add act_real, Fix vae=false, Add support for no state

---
 .gitignore                                    |   1 -
 .../common/policies/act/configuration_act.py  |  26 +++--
 lerobot/common/policies/act/modeling_act.py   |  50 ++++++---
 .../diffusion/configuration_diffusion.py      |  26 +++--
 .../policies/tdmpc/configuration_tdmpc.py     |   9 ++
 lerobot/configs/env/aloha2_real.yaml          |  13 +++
 lerobot/configs/policy/act_real.yaml          | 101 ++++++++++++++++++
 7 files changed, 190 insertions(+), 36 deletions(-)
 create mode 100644 lerobot/configs/env/aloha2_real.yaml
 create mode 100644 lerobot/configs/policy/act_real.yaml

diff --git a/.gitignore b/.gitignore
index 5b73b9ad..4ccf404d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -121,7 +121,6 @@ celerybeat.pid
 # Environments
 .env
 .venv
-env/
 venv/
 ENV/
 env.bak/
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 95374f4d..82bc6d8e 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -25,6 +25,14 @@ class ACTConfig:
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
     Those are: `input_shapes` and 'output_shapes`.
 
+    Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
+        - At least one key starting with "observation.image" is required as an input.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views.
+          Right now we only support all images having the same shape.
+        - "action" is required as an output key.
+
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
@@ -33,15 +41,15 @@
             This should be no greater than the chunk size. For example, if the chunk size size 100, you may
             set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
             environment, and throws the other 50 out.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.images.top" refers to an input from the
-            "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesn't include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std" which
             subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index eafe677b..141dc862 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -200,25 +200,28 @@ class ACT(nn.Module):
         self.config = config
         # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
         # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+        self.has_state = "observation.state" in config.input_shapes
+        self.latent_dim = config.latent_dim
         if self.config.use_vae:
             self.vae_encoder = ACTEncoder(config)
             self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
             # Projection layer for joint-space configuration to hidden dimension.
-            self.vae_encoder_robot_state_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
-            )
+            if self.has_state:
+                self.vae_encoder_robot_state_input_proj = nn.Linear(
+                    config.input_shapes["observation.state"][0], config.dim_model
+                )
             # Projection layer for action (joint-space target) to hidden dimension.
             self.vae_encoder_action_input_proj = nn.Linear(
-                config.input_shapes["observation.state"][0], config.dim_model
+                config.output_shapes["action"][0], config.dim_model
             )
-            self.latent_dim = config.latent_dim
             # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
             self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
             # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
             # dimension.
+            num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
             self.register_buffer(
                 "vae_encoder_pos_enc",
-                create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+                create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
             )
 
         # Backbone for image feature extraction.
@@ -238,15 +241,17 @@ class ACT(nn.Module):
 
         # Transformer encoder input projections. The tokens will be structured like
         # [latent, robot_state, image_feature_map_pixels].
-        self.encoder_robot_state_input_proj = nn.Linear(
-            config.input_shapes["observation.state"][0], config.dim_model
-        )
+        if self.has_state:
+            self.encoder_robot_state_input_proj = nn.Linear(
+                config.input_shapes["observation.state"][0], config.dim_model
+            )
         self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
         self.encoder_img_feat_input_proj = nn.Conv2d(
             backbone_model.fc.in_features, config.dim_model, kernel_size=1
         )
         # Transformer encoder positional embeddings.
-        self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+        num_input_token_decoder = 2 if self.has_state else 1
+        self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
         self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
 
         # Transformer decoder.
@@ -285,7 +290,7 @@ class ACT(nn.Module):
                 "action" in batch
             ), "actions must be provided when using the variational objective in training mode."
 
-        batch_size = batch["observation.state"].shape[0]
+        batch_size = batch["observation.images"].shape[0]
 
         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
@@ -293,11 +298,16 @@ class ACT(nn.Module):
             cls_embed = einops.repeat(
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
-            robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-                1
-            )  # (B, 1, D)
+            if self.has_state:
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-            vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+
+            if self.has_state:
+                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+            else:
+                vae_encoder_input = [cls_embed, action_embed]
+            vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
 
             # Prepare fixed positional embedding.
             # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
@@ -317,6 +327,7 @@ class ACT(nn.Module):
         else:
             # When not using the VAE encoder, we set the latent to be all zeros.
             mu = log_sigma_x2 = None
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
            latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
                 batch["observation.state"].device
             )
@@ -326,8 +337,10 @@ class ACT(nn.Module):
         all_cam_features = []
         all_cam_pos_embeds = []
         images = batch["observation.images"]
+
         for cam_index in range(images.shape[-4]):
             cam_features = self.backbone(images[:, cam_index])["feature_map"]
+            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
             cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
             all_cam_features.append(cam_features)
@@ -337,13 +350,15 @@ class ACT(nn.Module):
         cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
 
         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        if self.has_state:
+            robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
         latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)
 
         # Stack encoder input and positional embeddings moving to (S, B, C).
+        encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
         encoder_in = torch.cat(
             [
-                torch.stack([latent_embed, robot_state_embed], axis=0),
+                torch.stack(encoder_in_feats, axis=0),
                 einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
@@ -357,6 +372,7 @@ class ACT(nn.Module):
 
         # Forward pass through the transformer modules.
         encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+        # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
         decoder_in = torch.zeros(
             (self.config.chunk_size, batch_size, self.config.dim_model),
             dtype=pos_embed.dtype,
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 81ff5de7..48783d89 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -26,21 +26,29 @@ class DiffusionConfig:
     The parameters you will most likely need to change are the ones which depend on the environment / sensors.
     Those are: `input_shapes` and `output_shapes`.
 
+    Notes on the inputs and outputs:
+        - "observation.state" is required as an input key.
+        - At least one key starting with "observation.image" is required as an input.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views.
+          Right now we only support all images having the same shape.
+        - "action" is required as an output key.
+
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
         horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
         n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
             See `DiffusionPolicy.select_action` for more details.
-        input_shapes: A dictionary defining the shapes of the input data for the policy.
-            The key represents the input data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "observation.image" refers to an input from
-            a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-            Importantly, shapes doesnt include batch dimension or temporal dimension.
-        output_shapes: A dictionary defining the shapes of the output data for the policy.
-            The key represents the output data name, and the value is a list indicating the dimensions
-            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-            14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std" which
            subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
index cf76fb08..49485c39 100644
--- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -31,6 +31,15 @@ class TDMPCConfig:
         n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
             action repeats in Q-learning or ask your favorite chatbot)
         horizon: Horizon for model predictive control.
+        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+            the input data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+            include batch dimension or temporal dimension.
+        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+            the output data name, and the value is a list indicating the dimensions of the corresponding data.
+            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
             and the value specifies the normalization mode to apply. The two available modes are "mean_std" which
             subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
diff --git a/lerobot/configs/env/aloha2_real.yaml b/lerobot/configs/env/aloha2_real.yaml
new file mode 100644
index 00000000..3053fc01
--- /dev/null
+++ b/lerobot/configs/env/aloha2_real.yaml
@@ -0,0 +1,13 @@
+# @package _global_
+
+fps: 30
+
+env:
+  name: dora
+  task: DoraAloha2-v0
+  state_dim: 14
+  action_dim: 14
+  fps: ${fps}
+  episode_length: 400
+  gym:
+    fps: ${fps}
diff --git a/lerobot/configs/policy/act_real.yaml b/lerobot/configs/policy/act_real.yaml
new file mode 100644
index 00000000..b786160e
--- /dev/null
+++ b/lerobot/configs/policy/act_real.yaml
@@ -0,0 +1,101 @@
+# @package _global_
+
+seed: 1000
+dataset_repo_id: cadene/aloha_v2_static_dora_test
+
+override_dataset_stats:
+  observation.images.cam_right_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
+  observation.images.cam_left_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
+  observation.images.cam_high:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
+  observation.images.cam_low:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1
+  save_freq: 10000
+  log_freq: 100
+  save_model: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]" + +eval: + n_episodes: 50 + batch_size: 50 + +# See `configuration_act.py` for more details. +policy: + name: act + + # Input / output structure. + n_obs_steps: 1 + chunk_size: 100 # chunk_size + n_action_steps: 100 + + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.images.cam_right_wrist: [3, 480, 640] + observation.images.cam_left_wrist: [3, 480, 640] + observation.images.cam_high: [3, 480, 640] + observation.images.cam_low: [3, 480, 640] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.images.cam_right_wrist: mean_std + observation.images.cam_left_wrist: mean_std + observation.images.cam_high: mean_std + observation.images.cam_low: mean_std + observation.state: mean_std + output_normalization_modes: + action: mean_std + + # Architecture. + # Vision backbone. + vision_backbone: resnet18 + pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1 + replace_final_stride_with_dilation: false + # Transformer layers. + pre_norm: false + dim_model: 512 + n_heads: 8 + dim_feedforward: 3200 + feedforward_activation: relu + n_encoder_layers: 4 + # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code + # that means only the first layer is used. Here we match the original implementation by setting this to 1. + # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. + n_decoder_layers: 1 + # VAE. + use_vae: true + latent_dim: 32 + n_vae_encoder_layers: 4 + + # Inference. + temporal_ensemble_momentum: null + + # Training and loss computation. + dropout: 0.1 + kl_weight: 10.0