diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b3411e11..a86193b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -146,7 +146,8 @@ jobs: device=cpu \ save_model=true \ save_freq=2 \ - horizon=20 \ + policy.n_action_steps=20 \ + policy.chunk_size=20 \ policy.batch_size=2 \ hydra.run.dir=tests/outputs/act/ diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py new file mode 100644 index 00000000..74ed270e --- /dev/null +++ b/lerobot/common/policies/act/configuration_act.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass, field + + +@dataclass +class ActionChunkingTransformerConfig: + """Configuration class for the Action Chunking Transformers policy. + + Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `state_dim`, `action_dim` and `camera_names`. + + Args: + state_dim: Dimensionality of the observation state space (excluding images). + action_dim: Dimensionality of the action space. + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + camera_names: The (unique) set of names for the cameras. + chunk_size: The size of the action prediction "chunks" in units of environment steps. + n_action_steps: The number of action steps to run in the environment for one invocation of the policy. + This should be no greater than the chunk size. For example, if the chunk size is 100, you may + set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the + environment, and throws the other 50 out. + image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in + [0, 1]) for normalization. + image_normalization_std: Value by which to divide the input image pixels (after the mean has been + subtracted). + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. + use_pretrained_backbone: Whether the backbone should be initialized with ImageNet pretrained weights + from torchvision. + replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated + convolution. + pre_norm: Whether to use "pre-norm" in the transformer blocks. + d_model: The transformer blocks' main hidden dimension. + n_heads: The number of heads to use in the transformer blocks' multi-head attention. + dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward + layers. + feedforward_activation: The activation to use in the transformer block's feed-forward layers. + n_encoder_layers: The number of transformer layers to use for the transformer encoder. + n_decoder_layers: The number of transformer layers to use for the transformer decoder. + use_vae: Whether to use a variational objective during training. This introduces another transformer + which is used as the VAE's encoder (not to be confused with the transformer encoder - see + documentation in the policy class). + latent_dim: The VAE's latent dimension. + n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder. + use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given + environment step. + dropout: Dropout to use in the transformer layers (see code for details). 
+ kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective + is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. + """ + + # Environment. + state_dim: int = 14 + action_dim: int = 14 + + # Inputs / output structure. + n_obs_steps: int = 1 + camera_names: list[str] = field(default_factory=lambda: ["top"]) + chunk_size: int = 100 + n_action_steps: int = 100 + + # Vision preprocessing. + image_normalization_mean: tuple[float, float, float] = field( + default_factory=lambda: [0.485, 0.456, 0.406] + ) + image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225]) + + # Architecture. + # Vision backbone. + vision_backbone: str = "resnet18" + use_pretrained_backbone: bool = True + replace_final_stride_with_dilation: int = False + # Transformer layers. + pre_norm: bool = False + d_model: int = 512 + n_heads: int = 8 + dim_feedforward: int = 3200 + feedforward_activation: str = "relu" + n_encoder_layers: int = 4 + n_decoder_layers: int = 1 + # VAE. + use_vae: bool = True + latent_dim: int = 32 + n_vae_encoder_layers: int = 4 + + # Inference. + use_temporal_aggregation: bool = False + + # Training and loss computation. + dropout: float = 0.1 + kl_weight: float = 10.0 + + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: int = 8 + lr: float = 1e-5 + lr_backbone: float = 1e-5 + weight_decay: float = 1e-4 + grad_clip_norm: float = 10 + utd: int = 1 + + def __post_init__(self): + """Input validation.""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError("`vision_backbone` must be one of the ResNet variants.") + if self.use_temporal_aggregation: + raise NotImplementedError("Temporal aggregation is not yet implemented.") + if self.n_action_steps > self.chunk_size: + raise ValueError( + "The chunk size is the upper bound for the number of action steps per model invocation." + ) + if self.camera_names != ["top"]: + raise ValueError("For now, `camera_names` can only be ['top']") diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/modeling_act.py similarity index 80% rename from lerobot/common/policies/act/policy.py rename to lerobot/common/policies/act/modeling_act.py index 24667795..1361e071 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -20,7 +20,7 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.common.utils import get_safe_torch_device +from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig class ActionChunkingTransformerPolicy(nn.Module): @@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module): "ActionChunkingTransformerPolicy does not handle multiple observation steps." ) - def __init__(self, cfg, device): + def __init__(self, cfg: ActionChunkingTransformerConfig): """ TODO(alexander-soare): Add documentation for all parameters once we have model configs established. 
""" @@ -73,79 +73,64 @@ class ActionChunkingTransformerPolicy(nn.Module): if getattr(cfg, "n_obs_steps", 1) != 1: raise ValueError(self._multiple_obs_steps_not_handled_msg) self.cfg = cfg - self.n_action_steps = cfg.n_action_steps - self.device = get_safe_torch_device(device) - self.camera_names = cfg.camera_names - self.use_vae = cfg.use_vae - self.horizon = cfg.horizon - self.d_model = cfg.d_model - - transformer_common_kwargs = dict( # noqa: C408 - d_model=self.d_model, - num_heads=cfg.num_heads, - dim_feedforward=cfg.dim_feedforward, - dropout=cfg.dropout, - activation=cfg.activation, - normalize_before=cfg.pre_norm, - ) # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). - if self.use_vae: - self.vae_encoder = _TransformerEncoder(num_layers=cfg.vae_enc_layers, **transformer_common_kwargs) - self.vae_encoder_cls_embed = nn.Embedding(1, self.d_model) + if self.cfg.use_vae: + self.vae_encoder = _TransformerEncoder(cfg) + self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) # Projection layer for joint-space configuration to hidden dimension. - self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model) + self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) # Projection layer for action (joint-space target) to hidden dimension. - self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, self.d_model) + self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) self.latent_dim = cfg.latent_dim # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(self.d_model, self.latent_dim * 2) + self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2) # Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch # dimension. self.register_buffer( "vae_encoder_pos_enc", - _create_sinusoidal_position_embedding(1 + 1 + self.horizon, self.d_model).unsqueeze(0), + _create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0), ) # Backbone for image feature extraction. self.image_normalizer = transforms.Normalize( - mean=cfg.image_normalization.mean, std=cfg.image_normalization.std + mean=cfg.image_normalization_mean, std=cfg.image_normalization_std ) - backbone_model = getattr(torchvision.models, cfg.backbone)( - replace_stride_with_dilation=[False, False, cfg.dilation], - pretrained=cfg.pretrained_backbone, + backbone_model = getattr(torchvision.models, cfg.vision_backbone)( + replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation], + pretrained=cfg.use_pretrained_backbone, norm_layer=FrozenBatchNorm2d, ) + # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature + # map). # Note: The forward method of this returns a dict: {"feature_map": output}. self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) # Transformer (acts as VAE decoder when training with the variational objective). - self.encoder = _TransformerEncoder(num_layers=cfg.enc_layers, **transformer_common_kwargs) - self.decoder = _TransformerDecoder(num_layers=cfg.dec_layers, **transformer_common_kwargs) + self.encoder = _TransformerEncoder(cfg) + self.decoder = _TransformerDecoder(cfg) # Transformer encoder input projections. 
The tokens will be structured like # [latent, robot_state, image_feature_map_pixels]. - self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, self.d_model) - self.encoder_latent_input_proj = nn.Linear(self.latent_dim, self.d_model) + self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model) + self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model) self.encoder_img_feat_input_proj = nn.Conv2d( - backbone_model.fc.in_features, self.d_model, kernel_size=1 + backbone_model.fc.in_features, cfg.d_model, kernel_size=1 ) # Transformer encoder positional embeddings. - self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, self.d_model) - self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(self.d_model // 2) + self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model) + self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2) # Transformer decoder. # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). - self.decoder_pos_embed = nn.Embedding(self.horizon, self.d_model) + self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model) # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(self.d_model, cfg.action_dim) + self.action_head = nn.Linear(cfg.d_model, cfg.action_dim) self._reset_parameters() - self._create_optimizer() - self.to(self.device) def _create_optimizer(self): optimizer_params_dicts = [ @@ -173,8 +158,8 @@ class ActionChunkingTransformerPolicy(nn.Module): def reset(self): """This should be called whenever the environment is reset.""" - if self.n_action_steps is not None: - self._action_queue = deque([], maxlen=self.n_action_steps) + if self.cfg.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.cfg.n_action_steps) @torch.no_grad def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: @@ -184,8 +169,8 @@ class ActionChunkingTransformerPolicy(nn.Module): queue is empty. """ if len(self._action_queue) == 0: - # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape - # (n_action_steps, batch_size, *), hence the transpose. + # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has + # shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) return self._action_queue.popleft() @@ -197,20 +182,7 @@ class ActionChunkingTransformerPolicy(nn.Module): action = self.forward(batch, return_loss=False) - if self.cfg.temporal_agg: - # TODO(rcadene): implement temporal aggregation - raise NotImplementedError() - # all_time_actions[[t], t:t+num_queries] = action - # actions_for_curr_step = all_time_actions[:, t] - # actions_populated = torch.all(actions_for_curr_step != 0, axis=1) - # actions_for_curr_step = actions_for_curr_step[actions_populated] - # k = 0.01 - # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) - # exp_weights = exp_weights / exp_weights.sum() - # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) - # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) - - return action[: self.n_action_steps] + return action[: self.cfg.n_action_steps] def __call__(self, *args, **kwargs) -> dict: # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method. 
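
The hunks above move the chunking parameters into the config (`chunk_size`, `n_action_steps`) and keep the action-queue inference path. As a reference for reviewers, here is a minimal, standalone sketch of that contract (not the repo's API: `select_actions` and `select_action` below are stand-ins for the policy methods): the model predicts `chunk_size` actions per invocation, only the first `n_action_steps` are queued, and one action is popped per environment step until the queue needs refilling.

```python
from collections import deque

import torch

chunk_size, n_action_steps, action_dim, batch_size = 100, 50, 14, 1


def select_actions() -> torch.Tensor:
    # Stand-in for the policy's `select_actions`: predict a full chunk, then keep only the
    # first `n_action_steps` steps of it.
    chunk = torch.randn(batch_size, chunk_size, action_dim)
    return chunk[:, :n_action_steps]


action_queue: deque[torch.Tensor] = deque([], maxlen=n_action_steps)


def select_action() -> torch.Tensor:
    if len(action_queue) == 0:
        # The prediction is (batch_size, n_action_steps, action_dim) but the queue holds one
        # entry per step, hence the transpose to (n_action_steps, batch_size, action_dim).
        action_queue.extend(select_actions().transpose(0, 1))
    return action_queue.popleft()


for _ in range(3):  # one call per environment step
    print(select_action().shape)  # torch.Size([1, 14])
```
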
@@ -251,9 +223,9 @@ class ActionChunkingTransformerPolicy(nn.Module): self.train() num_slices = self.cfg.batch_size - batch_size = self.cfg.horizon * num_slices + batch_size = self.cfg.chunk_size * num_slices - assert batch_size % self.cfg.horizon == 0 + assert batch_size % self.cfg.chunk_size == 0 assert batch_size % num_slices == 0 loss = self.forward(batch, return_loss=True)["loss"] @@ -324,7 +296,7 @@ class ActionChunkingTransformerPolicy(nn.Module): Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the latent dimension. """ - if self.use_vae and self.training: + if self.cfg.use_vae and self.training: assert ( actions is not None ), "actions must be provided when using the variational objective in training mode." @@ -332,7 +304,7 @@ class ActionChunkingTransformerPolicy(nn.Module): batch_size = robot_state.shape[0] # Prepare the latent for input to the transformer encoder. - if self.use_vae and actions is not None: + if self.cfg.use_vae and actions is not None: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size @@ -367,7 +339,7 @@ class ActionChunkingTransformerPolicy(nn.Module): # Camera observation features and positional embeddings. all_cam_features = [] all_cam_pos_embeds = [] - for cam_id, _ in enumerate(self.camera_names): + for cam_id, _ in enumerate(self.cfg.camera_names): cam_features = self.backbone(image[:, cam_id])["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) @@ -399,7 +371,9 @@ class ActionChunkingTransformerPolicy(nn.Module): # Forward pass through the transformer modules. 
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed) decoder_in = torch.zeros( - (self.horizon, batch_size, self.d_model), dtype=pos_embed.dtype, device=pos_embed.device + (self.cfg.chunk_size, batch_size, self.cfg.d_model), + dtype=pos_embed.dtype, + device=pos_embed.device, ) decoder_out = self.decoder( decoder_in, @@ -426,16 +400,10 @@ class ActionChunkingTransformerPolicy(nn.Module): class _TransformerEncoder(nn.Module): """Convenience module for running multiple encoder layers, maybe followed by normalization.""" - def __init__(self, num_layers: int, **encoder_layer_kwargs: dict): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() - self.layers = nn.ModuleList( - [_TransformerEncoderLayer(**encoder_layer_kwargs) for _ in range(num_layers)] - ) - self.norm = ( - nn.LayerNorm(encoder_layer_kwargs["d_model"]) - if encoder_layer_kwargs["normalize_before"] - else nn.Identity() - ) + self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) + self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: for layer in self.layers: @@ -445,39 +413,31 @@ class _TransformerEncoder(nn.Module): class _TransformerEncoderLayer(nn.Module): - def __init__( - self, - d_model: int, - num_heads: int, - dim_feedforward: int, - dropout: float, - activation: str, - normalize_before: bool, - ): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) + self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) # Feed forward layers. - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) + self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) + self.dropout = nn.Dropout(cfg.dropout) + self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(cfg.d_model) + self.norm2 = nn.LayerNorm(cfg.d_model) + self.dropout1 = nn.Dropout(cfg.dropout) + self.dropout2 = nn.Dropout(cfg.dropout) - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before + self.activation = _get_activation_fn(cfg.feedforward_activation) + self.pre_norm = cfg.pre_norm def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: skip = x - if self.normalize_before: + if self.pre_norm: x = self.norm1(x) q = k = x if pos_embed is None else x + pos_embed x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = skip + self.dropout1(x) - if self.normalize_before: + if self.pre_norm: skip = x x = self.norm2(x) else: @@ -485,20 +445,17 @@ class _TransformerEncoderLayer(nn.Module): skip = x x = self.linear2(self.dropout(self.activation(self.linear1(x)))) x = skip + self.dropout2(x) - if not self.normalize_before: + if not self.pre_norm: x = self.norm2(x) return x class _TransformerDecoder(nn.Module): - def __init__(self, num_layers: int, **decoder_layer_kwargs): + def __init__(self, cfg: ActionChunkingTransformerConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() - self.layers = nn.ModuleList( - [_TransformerDecoderLayer(**decoder_layer_kwargs) for _ in range(num_layers)] 
- ) - self.num_layers = num_layers - self.norm = nn.LayerNorm(decoder_layer_kwargs["d_model"]) + self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) + self.norm = nn.LayerNorm(cfg.d_model) def forward( self, @@ -517,33 +474,25 @@ class _TransformerDecoder(nn.Module): class _TransformerDecoderLayer(nn.Module): - def __init__( - self, - d_model: int, - num_heads: int, - dim_feedforward: int, - dropout: float, - activation: str, - normalize_before: bool, - ): + def __init__(self, cfg: ActionChunkingTransformerConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) - self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout) + self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) + self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) # Feed forward layers. - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) + self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward) + self.dropout = nn.Dropout(cfg.dropout) + self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(cfg.d_model) + self.norm2 = nn.LayerNorm(cfg.d_model) + self.norm3 = nn.LayerNorm(cfg.d_model) + self.dropout1 = nn.Dropout(cfg.dropout) + self.dropout2 = nn.Dropout(cfg.dropout) + self.dropout3 = nn.Dropout(cfg.dropout) - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before + self.activation = _get_activation_fn(cfg.feedforward_activation) + self.pre_norm = cfg.pre_norm def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: return tensor if pos_embed is None else tensor + pos_embed @@ -566,12 +515,12 @@ class _TransformerDecoderLayer(nn.Module): (DS, B, C) tensor of decoder output features. 
""" skip = x - if self.normalize_before: + if self.pre_norm: x = self.norm1(x) q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = skip + self.dropout1(x) - if self.normalize_before: + if self.pre_norm: skip = x x = self.norm2(x) else: @@ -583,7 +532,7 @@ class _TransformerDecoderLayer(nn.Module): value=encoder_out, )[0] # select just the output, not the attention weights x = skip + self.dropout2(x) - if self.normalize_before: + if self.pre_norm: skip = x x = self.norm3(x) else: @@ -591,7 +540,7 @@ class _TransformerDecoderLayer(nn.Module): skip = x x = self.linear2(self.dropout(self.activation(self.linear1(x)))) x = skip + self.dropout3(x) - if not self.normalize_before: + if not self.pre_norm: x = self.norm3(x) return x diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index a287614d..f0454b8e 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,3 +1,10 @@ +import inspect + +from omegaconf import OmegaConf + +from lerobot.common.utils import get_safe_torch_device + + def make_policy(cfg): if cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPCPolicy @@ -19,10 +26,22 @@ def make_policy(cfg): **cfg.policy, ) elif cfg.policy.name == "act": - from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy + from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig + from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy - policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device) - policy.to(cfg.device) + expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters) + assert set(cfg.policy).issuperset( + expected_kwargs + ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}" + policy_cfg = ActionChunkingTransformerConfig( + **{ + k: v + for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items() + if k in expected_kwargs + } + ) + policy = ActionChunkingTransformerPolicy(policy_cfg) + policy.to(get_safe_torch_device(cfg.device)) else: raise ValueError(cfg.policy.name) diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index e2074b46..bd883613 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -8,61 +8,65 @@ eval_freq: 10000 save_freq: 100000 log_freq: 250 -horizon: 100 n_obs_steps: 1 # when temporal_agg=False, n_action_steps=horizon -n_action_steps: ${horizon} +# See `configuration_act.py` for more details. policy: name: act pretrained_model_path: + # Environment. + # Inherit these from the environment. + state_dim: ??? + action_dim: ??? + + # Inputs / output structure. + n_obs_steps: ${n_obs_steps} + camera_names: [top] # [top, front_close, left_pillar, right_pillar] + chunk_size: 100 # chunk_size + n_action_steps: 100 + + # Vision preprocessing. + image_normalization_mean: [0.485, 0.456, 0.406] + image_normalization_std: [0.229, 0.224, 0.225] + + # Architecture. + # Vision backbone. + vision_backbone: resnet18 + use_pretrained_backbone: true + replace_final_stride_with_dilation: false + # Transformer layers. + pre_norm: false + d_model: 512 + n_heads: 8 + dim_feedforward: 3200 + feedforward_activation: relu + n_encoder_layers: 4 + n_decoder_layers: 1 + # VAE. + use_vae: true + latent_dim: 32 + n_vae_encoder_layers: 4 + + # Inference. 
+ use_temporal_aggregation: false + + # Training and loss computation. + dropout: 0.1 + kl_weight: 10.0 + + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: 8 lr: 1e-5 lr_backbone: 1e-5 - pretrained_backbone: true weight_decay: 1e-4 grad_clip_norm: 10 - backbone: resnet18 - horizon: ${horizon} # chunk_size - kl_weight: 10 - d_model: 512 - dim_feedforward: 3200 - vae_enc_layers: 4 - enc_layers: 4 - dec_layers: 1 - num_heads: 8 - #camera_names: [top, front_close, left_pillar, right_pillar] - camera_names: [top] - dilation: false - dropout: 0.1 - pre_norm: false - activation: relu - latent_dim: 32 - - use_vae: true - - batch_size: 8 - - per_alpha: 0.6 - per_beta: 0.4 - - balanced_sampling: false utd: 1 - n_obs_steps: ${n_obs_steps} - n_action_steps: ${n_action_steps} - - temporal_agg: false - - state_dim: 14 - action_dim: 14 - - image_normalization: - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - delta_timestamps: - observation.images.top: [0.0] - observation.state: [0.0] - action: "[i / ${fps} for i in range(${horizon})]" + observation.images.top: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" + observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" + action: "[i / ${fps} for i in range(${policy.chunk_size})]" diff --git a/tests/test_available.py b/tests/test_available.py index be74a42a..b25a921f 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -18,7 +18,7 @@ from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.datasets.aloha import AlohaDataset from lerobot.common.datasets.pusht import PushtDataset -from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy +from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.diffusion.policy import DiffusionPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
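
For reviewers who want to exercise the new config class outside of Hydra, a minimal sketch is below, assuming this branch is installed. The `overrides` dict (including `unused_key`) is made up for illustration; the key filtering mirrors the `make_policy` change above, and the second call shows the `__post_init__` validation rejecting `n_action_steps > chunk_size`.

```python
import inspect

from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig

# Hypothetical Hydra-style overrides; keys not accepted by the dataclass are dropped.
overrides = {"chunk_size": 20, "n_action_steps": 20, "state_dim": 14, "action_dim": 14, "unused_key": 1}
expected = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
cfg = ActionChunkingTransformerConfig(**{k: v for k, v in overrides.items() if k in expected})
print(cfg.chunk_size, cfg.n_action_steps)  # 20 20

# Input validation: the chunk size bounds the number of action steps per invocation.
try:
    ActionChunkingTransformerConfig(chunk_size=10, n_action_steps=20)
except ValueError as e:
    print(e)
```
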