Add aloha2_real, Add act_real, Fix vae=false, Add support for no state
parent 57fb5fe8a6
commit 49a3db9f2f
@@ -121,7 +121,6 @@ celerybeat.pid
 # Environments
 .env
 .venv
-env/
 venv/
 ENV/
 env.bak/
@@ -25,6 +25,14 @@ class ACTConfig:
 The parameters you will most likely need to change are the ones which depend on the environment / sensors.
 Those are: `input_shapes` and `output_shapes`.

+Notes on the inputs and outputs:
+    - "observation.state" is required as an input key.
+    - At least one key starting with "observation.image" is required as an input.
+    - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+      views.
+      Right now we only support all images having the same shape.
+    - "action" is required as an output key.
+
 Args:
     n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
         current step and additional steps going back).
@@ -33,15 +41,15 @@ class ACTConfig:
         This should be no greater than the chunk size. For example, if the chunk size is 100, you may
         set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
         environment, and throws the other 50 out.
-    input_shapes: A dictionary defining the shapes of the input data for the policy.
-        The key represents the input data name, and the value is a list indicating the dimensions
-        of the corresponding data. For example, "observation.images.top" refers to an input from the
-        "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-        Importantly, shapes doesn't include batch dimension or temporal dimension.
-    output_shapes: A dictionary defining the shapes of the output data for the policy.
-        The key represents the output data name, and the value is a list indicating the dimensions
-        of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-        14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+    input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+        the input data name, and the value is a list indicating the dimensions of the corresponding data.
+        For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+        indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+        include batch dimension or temporal dimension.
+    output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+        the output data name, and the value is a list indicating the dimensions of the corresponding data.
+        For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+        Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
     input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
         and the value specifies the normalization mode to apply. The two available modes are "mean_std"
         which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
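As a concrete illustration of the `input_shapes` / `output_shapes` convention described in the docstring above, the dictionaries would look like the following (a sketch using the example values from the docstring; the actual keys depend on your environment and cameras):

# Per-sample shapes only: no batch or temporal dimensions are included.
input_shapes = {
    "observation.image": [3, 96, 96],  # one camera view: 3 color channels, 96x96 resolution
    "observation.state": [14],         # 14-dimensional robot state
}
output_shapes = {
    "action": [14],                    # 14-dimensional actions
}
# Multiple keys starting with "observation.image" would be treated as multiple camera views.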
@@ -200,25 +200,28 @@ class ACT(nn.Module):
 self.config = config
 # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
 # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
+self.has_state = "observation.state" in config.input_shapes
+self.latent_dim = config.latent_dim
 if self.config.use_vae:
     self.vae_encoder = ACTEncoder(config)
     self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
     # Projection layer for joint-space configuration to hidden dimension.
+    if self.has_state:
         self.vae_encoder_robot_state_input_proj = nn.Linear(
             config.input_shapes["observation.state"][0], config.dim_model
         )
     # Projection layer for action (joint-space target) to hidden dimension.
     self.vae_encoder_action_input_proj = nn.Linear(
-        config.input_shapes["observation.state"][0], config.dim_model
+        config.output_shapes["action"][0], config.dim_model
     )
-    self.latent_dim = config.latent_dim
     # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
     self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
     # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
     # dimension.
+    num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
     self.register_buffer(
         "vae_encoder_pos_enc",
-        create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
+        create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
     )

 # Backbone for image feature extraction.
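The `num_input_token_encoder` expression added above encodes the VAE encoder's token layout; spelled out with the default chunk size (illustrative arithmetic only):

# VAE encoder input tokens per sample:
#   with "observation.state":    [cls, robot_state, action_1, ..., action_S] -> 1 + 1 + chunk_size
#   without "observation.state": [cls, action_1, ..., action_S]              -> 1 + chunk_size
chunk_size = 100
num_tokens_with_state = 1 + 1 + chunk_size  # 102 sinusoidal positions
num_tokens_without_state = 1 + chunk_size   # 101 sinusoidal positions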
@@ -238,6 +241,7 @@ class ACT(nn.Module):

 # Transformer encoder input projections. The tokens will be structured like
 # [latent, robot_state, image_feature_map_pixels].
+if self.has_state:
     self.encoder_robot_state_input_proj = nn.Linear(
         config.input_shapes["observation.state"][0], config.dim_model
     )
@@ -246,7 +250,8 @@ class ACT(nn.Module):
     backbone_model.fc.in_features, config.dim_model, kernel_size=1
 )
 # Transformer encoder positional embeddings.
-self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
+num_input_token_decoder = 2 if self.has_state else 1
+self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
 self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

 # Transformer decoder.
@@ -285,7 +290,7 @@ class ACT(nn.Module):
     "action" in batch
 ), "actions must be provided when using the variational objective in training mode."

-batch_size = batch["observation.state"].shape[0]
+batch_size = batch["observation.images"].shape[0]

 # Prepare the latent for input to the transformer encoder.
 if self.config.use_vae and "action" in batch:
@@ -293,11 +298,16 @@ class ACT(nn.Module):
 cls_embed = einops.repeat(
     self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
 )  # (B, 1, D)
-robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
-    1
-)  # (B, 1, D)
+if self.has_state:
+    robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+    robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
 action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)
-vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)  # (B, S+2, D)
+
+if self.has_state:
+    vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+else:
+    vae_encoder_input = [cls_embed, action_embed]
+vae_encoder_input = torch.cat(vae_encoder_input, axis=1)

 # Prepare fixed positional embedding.
 # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
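A standalone sketch of the input assembly above, showing the shapes produced with and without a state input (illustrative sizes: batch B=8, chunk size S=100, hidden dimension D=512):

import torch

B, S, D = 8, 100, 512  # illustrative batch size, chunk size, and model dimension
cls_embed = torch.zeros(B, 1, D)
robot_state_embed = torch.zeros(B, 1, D)
action_embed = torch.zeros(B, S, D)

with_state = torch.cat([cls_embed, robot_state_embed, action_embed], dim=1)
without_state = torch.cat([cls_embed, action_embed], dim=1)
print(with_state.shape)     # torch.Size([8, 102, 512]) -> (B, S + 2, D)
print(without_state.shape)  # torch.Size([8, 101, 512]) -> (B, S + 1, D)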
@@ -317,6 +327,7 @@ class ACT(nn.Module):
 else:
     # When not using the VAE encoder, we set the latent to be all zeros.
     mu = log_sigma_x2 = None
+    # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
     latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
         batch["observation.state"].device
     )
@@ -326,8 +337,10 @@ class ACT(nn.Module):
 all_cam_features = []
 all_cam_pos_embeds = []
 images = batch["observation.images"]
+
 for cam_index in range(images.shape[-4]):
     cam_features = self.backbone(images[:, cam_index])["feature_map"]
+    # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
     cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
     cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
     all_cam_features.append(cam_features)
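For reference, `batch["observation.images"]` in the loop above stacks the camera views along a dedicated dimension, so `images.shape[-4]` is the number of cameras (a sketch with sizes borrowed from the act_real config further down; the exact tensor layout is an assumption):

import torch

B = 8
images = torch.zeros(B, 4, 3, 480, 640)    # assumed (batch, num_cameras, channels, height, width)
for cam_index in range(images.shape[-4]):  # shape[-4] == 4 camera views
    frame = images[:, cam_index]           # (B, 3, 480, 640): one camera at a time goes to the backbone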
@@ -337,13 +350,15 @@ class ACT(nn.Module):
 cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)

 # Get positional embeddings for robot state and latent.
+if self.has_state:
     robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
 latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)

 # Stack encoder input and positional embeddings moving to (S, B, C).
+encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
 encoder_in = torch.cat(
     [
-        torch.stack([latent_embed, robot_state_embed], axis=0),
+        torch.stack(encoder_in_feats, axis=0),
         einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
     ]
 )
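Putting the pieces together, the transformer encoder's sequence length works out as follows (a sketch; the 15x20 feature map per camera is an assumption for a ResNet18 backbone on 480x640 input, and camera feature maps are assumed to be concatenated along width like the positional embeddings above):

# Encoder token sequence: [latent, (robot_state), image_feature_map_pixels]
h, w, num_cameras = 15, 20, 4
num_image_tokens = h * w * num_cameras        # 1200 flattened feature-map pixels
seq_len_with_state = 2 + num_image_tokens     # latent + robot_state + image tokens
seq_len_without_state = 1 + num_image_tokens  # latent + image tokens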
@@ -357,6 +372,7 @@ class ACT(nn.Module):

 # Forward pass through the transformer modules.
 encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
+# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
 decoder_in = torch.zeros(
     (self.config.chunk_size, batch_size, self.config.dim_model),
     dtype=pos_embed.dtype,
@@ -26,21 +26,29 @@ class DiffusionConfig:
 The parameters you will most likely need to change are the ones which depend on the environment / sensors.
 Those are: `input_shapes` and `output_shapes`.

+Notes on the inputs and outputs:
+    - "observation.state" is required as an input key.
+    - At least one key starting with "observation.image" is required as an input.
+    - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+      views.
+      Right now we only support all images having the same shape.
+    - "action" is required as an output key.
+
 Args:
     n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
         current step and additional steps going back).
     horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
     n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
        See `DiffusionPolicy.select_action` for more details.
-    input_shapes: A dictionary defining the shapes of the input data for the policy.
-        The key represents the input data name, and the value is a list indicating the dimensions
-        of the corresponding data. For example, "observation.image" refers to an input from
-        a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-        Importantly, shapes doesnt include batch dimension or temporal dimension.
-    output_shapes: A dictionary defining the shapes of the output data for the policy.
-        The key represents the output data name, and the value is a list indicating the dimensions
-        of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-        14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
+    input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+        the input data name, and the value is a list indicating the dimensions of the corresponding data.
+        For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+        indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+        include batch dimension or temporal dimension.
+    output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+        the output data name, and the value is a list indicating the dimensions of the corresponding data.
+        For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+        Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
     input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
         and the value specifies the normalization mode to apply. The two available modes are "mean_std"
         which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -31,6 +31,15 @@ class TDMPCConfig:
     n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
         action repeats in Q-learning or ask your favorite chatbot)
     horizon: Horizon for model predictive control.
+    input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+        the input data name, and the value is a list indicating the dimensions of the corresponding data.
+        For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+        indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+        include batch dimension or temporal dimension.
+    output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+        the output data name, and the value is a list indicating the dimensions of the corresponding data.
+        For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+        Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
     input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
         and the value specifies the normalization mode to apply. The two available modes are "mean_std"
         which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -0,0 +1,13 @@
+# @package _global_
+
+fps: 30
+
+env:
+  name: dora
+  task: DoraAloha2-v0
+  state_dim: 14
+  action_dim: 14
+  fps: ${fps}
+  episode_length: 400
+  gym:
+    fps: ${fps}
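The `${fps}` references in this file are OmegaConf-style interpolations, so `env.fps` resolves to the top-level `fps` value; a minimal sketch of that behaviour, assuming the `omegaconf` package is available:

from omegaconf import OmegaConf

# Mirror of the structure above; interpolations resolve lazily on access.
cfg = OmegaConf.create({"fps": 30, "env": {"fps": "${fps}", "episode_length": 400}})
print(cfg.env.fps)                           # 30, taken from the top-level `fps`
print(cfg.env.episode_length / cfg.env.fps)  # ~13.3 seconds of wall time per episode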
@@ -0,0 +1,101 @@
+# @package _global_
+
+seed: 1000
+dataset_repo_id: cadene/aloha_v2_static_dora_test
+
+override_dataset_stats:
+  observation.images.cam_right_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_left_wrist:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_high:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.cam_low:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1
+  save_freq: 10000
+  log_freq: 100
+  save_model: true
+
+  batch_size: 8
+  lr: 1e-5
+  lr_backbone: 1e-5
+  weight_decay: 1e-4
+  grad_clip_norm: 10
+  online_steps_between_rollouts: 1
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+# See `configuration_act.py` for more details.
+policy:
+  name: act
+
+  # Input / output structure.
+  n_obs_steps: 1
+  chunk_size: 100 # chunk_size
+  n_action_steps: 100
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images.cam_right_wrist: [3, 480, 640]
+    observation.images.cam_left_wrist: [3, 480, 640]
+    observation.images.cam_high: [3, 480, 640]
+    observation.images.cam_low: [3, 480, 640]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.cam_right_wrist: mean_std
+    observation.images.cam_left_wrist: mean_std
+    observation.images.cam_high: mean_std
+    observation.images.cam_low: mean_std
+    observation.state: mean_std
+  output_normalization_modes:
+    action: mean_std
+
+  # Architecture.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  replace_final_stride_with_dilation: false
+  # Transformer layers.
+  pre_norm: false
+  dim_model: 512
+  n_heads: 8
+  dim_feedforward: 3200
+  feedforward_activation: relu
+  n_encoder_layers: 4
+  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
+  n_decoder_layers: 1
+  # VAE.
+  use_vae: true
+  latent_dim: 32
+  n_vae_encoder_layers: 4
+
+  # Inference.
+  temporal_ensemble_momentum: null
+
+  # Training and loss computation.
+  dropout: 0.1
+  kl_weight: 10.0