Add aloha2_real, Add act_real, Fix vae=false, Add support for no state

Remi Cadene 2024-05-30 12:06:57 +00:00 committed by Thomas Wolf
parent bd3111f28b
commit 5495d55cc7
5 changed files with 35 additions and 33 deletions

@@ -26,10 +26,11 @@ class ACTConfig:
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- At least one key starting with "observation.image" is required as an input.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views.
Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
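As a rough illustration of the updated notes, a minimal sketch of `input_shapes`/`output_shapes` dictionaries with and without the now-optional "observation.state" key; camera names and dimensions are assumed values, not taken from a specific dataset.

```python
# Minimal sketch of the shape dictionaries described above.
# Camera names and dimensions are assumed values for illustration only.
input_shapes_with_state = {
    "observation.images.top": [3, 480, 640],
    "observation.state": [14],
}
input_shapes_without_state = {
    "observation.images.top": [3, 480, 640],  # works without "observation.state"
}
output_shapes = {
    "action": [14],  # "action" is always required as an output key
}
```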

@@ -200,12 +200,13 @@ class ACT(nn.Module):
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_input_state = "observation.state" in config.input_shapes
self.has_state = "observation.state" in config.input_shapes
self.latent_dim = config.latent_dim
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.use_input_state:
if self.has_state:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
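A rough sketch of what the optional projection above does, assuming a 14-dim proprioceptive state and `dim_model=512`: the robot state is only projected to the transformer's hidden size when the key is present.

```python
import torch
from torch import nn

state_dim, dim_model, batch_size = 14, 512, 8  # assumed sizes for illustration
has_state = True

if has_state:
    robot_state_proj = nn.Linear(state_dim, dim_model)
    state = torch.randn(batch_size, state_dim)  # (B, state_dim)
    state_embed = robot_state_proj(state)       # (B, dim_model)
    print(state_embed.shape)                    # torch.Size([8, 512])
```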
@@ -217,9 +218,7 @@ class ACT(nn.Module):
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.use_input_state:
num_input_token_encoder += 1
num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
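A small sketch of the token count used for the VAE encoder positional embedding above, assuming lerobot's default `chunk_size=100`: the sequence is [cls] + an optional [robot_state] token + `chunk_size` action tokens.

```python
chunk_size = 100  # assumed ACT chunk length for illustration

for has_state in (True, False):
    num_input_token_encoder = 1 + 1 + chunk_size if has_state else 1 + chunk_size
    print(has_state, num_input_token_encoder)  # True -> 102, False -> 101
```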
@@ -242,16 +241,16 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
if self.use_input_state:
if self.has_state:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
num_input_token_decoder = 2 if self.use_input_state else 1
num_input_token_decoder = 2 if self.has_state else 1
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
@@ -299,12 +298,12 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_input_state:
if self.has_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_input_state:
if self.has_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
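A sketch of the concatenation above with assumed shapes (B=8, S=chunk_size=100, D=dim_model=512): the cls token, the optional robot-state token, and the action tokens are joined along the sequence dimension.

```python
import torch

B, S, D = 8, 100, 512  # assumed batch size, chunk size, and model dimension
cls_embed = torch.zeros(B, 1, D)
robot_state_embed = torch.zeros(B, 1, D)
action_embed = torch.zeros(B, S, D)

for has_state in (True, False):
    feats = [cls_embed, robot_state_embed, action_embed] if has_state else [cls_embed, action_embed]
    vae_encoder_input = torch.cat(feats, dim=1)
    print(vae_encoder_input.shape)  # (B, S+2, D) with state, (B, S+1, D) without
```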
@@ -329,7 +328,7 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
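For the `use_vae=False` path above, the latent is simply a zero vector per sample; a minimal sketch with assumed `batch_size` and `latent_dim` values.

```python
import torch

batch_size, latent_dim = 8, 32  # assumed values; latent_dim mirrors config.latent_dim
mu = log_sigma_x2 = None
latent_sample = torch.zeros([batch_size, latent_dim], dtype=torch.float32)
print(latent_sample.shape)  # torch.Size([8, 32])
```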
@@ -351,12 +350,12 @@ class ACT(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent.
if self.use_input_state:
if self.has_state:
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
encoder_in = torch.cat(
[
torch.stack(encoder_in_feats, axis=0),
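And a sketch of how the non-image tokens above are stacked into the (S, B, C) layout consumed by the transformer encoder (camera feature tokens are appended afterwards); dimensions are assumed.

```python
import torch

B, D = 8, 512  # assumed batch size and model dimension
latent_embed = torch.zeros(B, D)
robot_state_embed = torch.zeros(B, D)

for has_state in (True, False):
    encoder_in_feats = [latent_embed, robot_state_embed] if has_state else [latent_embed]
    encoder_in = torch.stack(encoder_in_feats, dim=0)
    print(encoder_in.shape)  # (2, B, D) with state, (1, B, D) without
```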

@@ -28,7 +28,10 @@ class DiffusionConfig:
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- A key starting with "observation.image" is required as an input.
- At least one key starting with "observation.image" is required as an input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views.
Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
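As with ACT above, a small illustrative sketch of the multi-camera note: several `observation.images.*` keys are allowed as long as every image has the same shape (camera names and resolution are assumed).

```python
input_shapes = {
    "observation.images.cam_left_wrist": [3, 480, 640],   # assumed names/resolution
    "observation.images.cam_right_wrist": [3, 480, 640],
    "observation.state": [14],
}
image_shapes = {tuple(v) for k, v in input_shapes.items() if k.startswith("observation.image")}
assert len(image_shapes) == 1  # all camera views must share the same shape
```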

lerobot/configs/env/aloha2_real.yaml (new file)

@@ -0,0 +1,13 @@
# @package _global_

fps: 30

env:
  name: dora
  task: DoraAloha2-v0
  state_dim: 14
  action_dim: 14
  fps: ${fps}
  episode_length: 400
  gym:
    fps: ${fps}

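The `state_dim`/`action_dim` of 14 in this new env config match a bimanual Aloha setup (two arms with six joints plus a gripper each); a sketch of how they would line up with a policy's shape dictionaries, assuming the key conventions described above.

```python
env_state_dim, env_action_dim = 14, 14  # from the env config above

# Assumed correspondence with the policy-side shape dictionaries.
input_shapes = {"observation.state": [env_state_dim]}
output_shapes = {"action": [env_action_dim]}
```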
@@ -1,21 +1,7 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This parameter is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=dora_aloha_real
# ```
seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup
dataset_repo_id: cadene/aloha_v2_static_dora_test
override_dataset_stats:
  observation.images.cam_right_wrist:
@@ -41,7 +27,7 @@ training:
  eval_freq: -1
  save_freq: 10000
  log_freq: 100
  save_checkpoint: true
  save_model: true
  batch_size: 8
  lr: 1e-5