more verbose var naming + update docstrings
This commit is contained in:
parent
ec0ae30fb8
commit
9bc8f89a3a
@@ -61,7 +61,7 @@ class OctoConfig:
         n_readouts_per_step: Number of learned readout heads to use per observation step.
         n_layers: Number of transformer encoder layers.
         n_heads: Number of heads in the transformer.
-        d_ffn: Dimension of the feedforward network in the transformer.
+        dim_feedforward: Dimension of the feedforward network in the transformer.
         p_dropout: Dropout rate in the attention and feedforward networks.
         time_dim: Dimension of the denoising iteration index feature projection.
         n_diffusion_head_layers: Number of layers in the action diffusion head.
@@ -124,7 +124,7 @@ class OctoConfig:
     n_readouts_per_step: int = 1
     n_layers: int = 12
     n_heads: int = 6
-    d_ffn: int = 1536
+    dim_feedforward: int = 1536
|
||||
p_dropout: int = 0.0
|
||||
time_dim: int = 32
|
||||
n_diffusion_head_layers: int = 3
|
||||
|
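Since the rename is API-breaking for anyone constructing the config in code, here is a minimal usage sketch with the new field name (the import path is an assumption; it is not shown in this diff):

```python
# Hypothetical import path; adjust to wherever OctoConfig lives in this repo.
from configuration_octo import OctoConfig

# Override the renamed field; passing the old name `d_ffn` would now raise a TypeError.
config = OctoConfig(dim_feedforward=2048, n_layers=8)
print(config.dim_feedforward)  # 2048
```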
@@ -140,7 +140,7 @@ class OctoConfig:
     clip_sample_range: float = 1.0

     # Inference
-    num_inference_steps: int | None = None
+    num_inference_steps: int | None = 20

     # Loss computation
     do_mask_loss_for_padding: bool = False
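The changed default means inference now runs a fixed 20 denoising iterations instead of falling back to the full training timestep count when the value is `None`. A sketch of the usual diffusers pattern this feeds into (`_make_noise_scheduler` later in this diff returns a `DDPMScheduler | DDIMScheduler`; `num_train_timesteps=100` is an assumed value):

```python
from diffusers import DDIMScheduler

# Assumed training setup: 100 diffusion timesteps.
scheduler = DDIMScheduler(num_train_timesteps=100)

# With the new default, sampling runs only 20 denoising iterations.
scheduler.set_timesteps(num_inference_steps=20)
print(len(scheduler.timesteps))  # 20
```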
@@ -67,9 +67,11 @@ class OctoPolicy(nn.Module, PyTorchModelHubMixin):
             that they will be passed with a call to `load_state_dict` before the policy is used.
         """
         super().__init__()

         if config is None:
             config = OctoConfig()
         self.config = config

         self.normalize_inputs = Normalize(
             config.input_shapes, config.input_normalization_modes, dataset_stats
         )
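With the `config is None` fallback above, the policy is constructible without arguments. A brief usage sketch (signature details beyond `config` and `dataset_stats` are assumptions):

```python
# Both are equivalent given the `config is None` fallback above; normalization
# stats can be supplied later via `load_state_dict`, per the docstring.
policy = OctoPolicy()                     # uses OctoConfig() defaults
policy = OctoPolicy(config=OctoConfig())  # explicit default config
```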
@@ -166,54 +168,60 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:


 class OctoModel(nn.Module):
+    """An overview of this minimal Octo model implementation:
+
+    There are two main components:
+    1) OctoTransformer, which processes an input sequence of state and image tokens (collectively called
+       "observation tokens") and readout tokens as follows:
+        - Normalized state inputs of shape (`n_obs_steps`, `state_dim`) are projected to
+          (`n_obs_steps`, `embed_dim`).
+        - Feature maps of the images generated by a vision encoder are flattened along the spatial dimension
+          and projected to (`n_obs_steps`, `n_img_features`, `embed_dim`).
+        - The above are concatenated to form a sequence of observation tokens of shape
+          (`n_obs_steps`, `n_obs_tokens_per_step`, `embed_dim`).
+        - Additionally, learned "readout" tokens of shape (`n_obs_steps`, `n_readouts_per_step`, `embed_dim`)
+          are appended after the tokens of each observation step to form the final input sequence for the
+          transformer encoder.
+       A causal mask (see make_causal_mask for an example viz) is used to prevent:
+        a) observation and readout tokens from attending to any future tokens,
+        b) observation tokens from attending to any future or prior readout tokens, and
+        c) readout tokens from attending to prior readout tokens.
+    2) OctoDiffusionActionHead, which predicts the noise to remove from a noisy trajectory, conditioned on
+       the mean of all readout embeddings and a projection of the denoising iteration K.
+
+    An example, with `n_obs_steps`=2, `n_readouts_per_step`=1, a 7-DoF state (joint angles, gripper, etc.)
+    and a 96x96 wrist image (which would result in 256x6x6 feature maps from resnet18):
+
+    `n_obs_tokens_per_step` = 6*6 + 1 = 37
+    `input_seq_len` = (`n_obs_tokens_per_step` + `n_readouts_per_step`) * `n_obs_steps` = 38*2 = 76
+    ---------------------------------------------------------------------------
+    Token Index | 0 | 1 | 2 | 3 | 4 | ... | 36 | 37 | 38 | 39 | ... | 74 | 75 |
+    ---------------------------------------------------------------------------
+    Obs Timestep| 0 | 0 | 0 | 0 | 0 | ... | 0  | 0  | 1  | 1  | ... | 1  | 1  |
+    ---------------------------------------------------------------------------
+    Token Type  |obs|obs|obs|obs|obs| ... |obs |rout|obs |obs | ... |obs |rout|
+    ------------|---|---|---|---|---|-----|----|----|----|----|-----|----|----|
+                                                |                          |
+                                                V                          V
+                                           <r_embed_1>                <r_embed_2>
+                                                |                          |
+                                                ---------> (Mean) <--------
+                                                             |
+                                                             V
+              <noisy_sample>, <K_proj> --> (Action Diffusion Head) --> <noise_pred>
+
+    Note that this implementation does not (yet) include certain features from the original Octo
+    implementation:
+    1) Language and Goal Conditioning: The original Octo supports conditioning on language and goal images,
+       which would be tokenized and prepended to the input sequence.
+    2) Multiple trajectory generation: The original Octo generates a trajectory for each readout token
+       (i.e., a trajectory starting at each observation step). This implementation only generates a single
+       trajectory using the mean of all readout tokens.
+    3) MAP over multiple readout tokens: The original Octo has an option to use Multihead Attention Pooling
+       over the readout tokens of each observation step. This implementation supports multiple readout
+       tokens but uses simple mean pooling over them.
+    """

     def __init__(self, config: OctoConfig):
-        """An overview of this minimal Octo model implementation:
-
-        There are two main components:
-        1) OctoTransformer, which processes an input sequence of state and image tokens (collectively called
-           "observation tokens") and readout tokens as follows:
-            - Normalized state inputs of shape (`n_obs_steps`, `state_dim`) are projected to
-              (`n_obs_steps`, `embed_dim`).
-            - Feature maps of the images generated by a vision encoder are flattened along the spatial
-              dimension and projected to (`n_obs_steps`, `n_img_features`, `embed_dim`).
-            - The above are concatenated to form a sequence of observation tokens of shape
-              (`n_obs_steps`, `n_obs_tokens_per_step`, `embed_dim`).
-            - Additionally, learned "readout" tokens of shape (`n_obs_steps`, `n_readouts_per_step`,
-              `embed_dim`) are appended after the tokens of each observation step to form the final input
-              sequence for the transformer encoder.
-           A causal mask (see make_causal_mask for an example viz) is used to prevent:
-            a) observation and readout tokens from attending to any future tokens,
-            b) observation tokens from attending to any readout tokens, and
-            c) readout tokens from attending to prior readout tokens.
-        2) OctoDiffusionActionHead, which predicts the noise to remove from a noisy trajectory, conditioned
-           on the mean of all readout embeddings and a projection of the denoising iteration K.
-
-        An example, with `n_obs_steps`=2, `n_readouts_per_step`=1, a 7-DoF state (joint angles, gripper,
-        etc.) and a 96x96 wrist image (which would result in 256x6x6 feature maps using resnet18):
-
-        `n_obs_tokens_per_step` = 6*6 + 1 + 1 = 38
-        ---------------------------------------------------------------------------
-        Token Index | 0 | 1 | 2 | 3 | 4 | ... | 36 | 37 | 38 | 39 | ... | 74 | 75 |
-        ---------------------------------------------------------------------------
-        Obs Timestep| 0 | 0 | 0 | 0 | 0 | ... | 0  | 0  | 1  | 1  | ... | 1  | 1  |
-        ---------------------------------------------------------------------------
-        Token Type  |obs|obs|obs|obs|obs| ... |obs |rout|obs |obs | ... |obs |rout|
-        ------------|---|---|---|---|---|-----|----|----|----|----|-----|----|----|
-                                                    |                          |
-                                                    V                          V
-                                               <r_embed_1>                <r_embed_2>
-                                                    |                          |
-                                                    ---------> (Mean) <--------
-                                                                 |
-                                                                 V
-                  <noisy_sample>, <K_proj> --> (Action Diffusion Head) --> <noise_pred>
-
-        Note that this implementation does not (yet) include certain features from the original Octo
-        implementation:
-        1) Language and Goal Conditioning: The original Octo supports conditioning on language and goal
-           images, which would be tokenized and prepended to the input sequence.
-        2) Multiple trajectory generation: The original Octo generates a trajectory for each readout token
-           (i.e., a trajectory starting at each observation step). This implementation only generates a
-           single trajectory using the mean of all readout tokens.
-        3) MAP over multiple readout tokens: The original Octo has an option to use Multihead Attention
-           Pooling over the readout tokens of each observation step. This implementation supports multiple
-           readout tokens but uses simple mean pooling over them.
-        """
+        """
+        Args:
+            config: OctoConfig, Policy configuration class instance.
+        """
         super().__init__()
         self.config = config
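The three mask rules in the docstring are easy to get subtly wrong, so here is a self-contained sketch of a mask that satisfies them. It illustrates the rules only; it is not the repo's `make_causal_mask`, and it blocks observation-to-readout attention entirely, per rule b):

```python
import torch


def example_causal_mask(n_obs_tokens_per_step: int, n_obs_steps: int, n_readouts_per_step: int) -> torch.Tensor:
    """Illustrative (hypothetical) version of the masking rules described above.

    Returns a boolean mask of shape (S, S) where True means "may attend".
    """
    tokens_per_step = n_obs_tokens_per_step + n_readouts_per_step
    seq_len = tokens_per_step * n_obs_steps

    # For each position: which observation step it belongs to, and whether it is a
    # readout token (readouts come after the observation tokens of each step).
    pos = torch.arange(seq_len)
    step = pos // tokens_per_step
    is_readout = (pos % tokens_per_step) >= n_obs_tokens_per_step

    q_step, k_step = step[:, None], step[None, :]
    q_rout, k_rout = is_readout[:, None], is_readout[None, :]

    mask = k_step <= q_step                          # a) nothing attends to future steps
    mask &= ~(~q_rout & k_rout)                      # b) observation tokens never attend readout tokens
    mask &= ~(q_rout & k_rout & (k_step < q_step))   # c) readouts don't attend prior readouts
    return mask


# Tiny example: 2 obs tokens + 1 readout per step, 2 steps -> a 6x6 mask.
print(example_causal_mask(n_obs_tokens_per_step=2, n_obs_steps=2, n_readouts_per_step=1))
```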
@@ -507,15 +515,15 @@ def make_causal_mask(n_obs_tokens_per_step, n_obs_steps, n_readouts_per_step):


 class OctoTransformer(nn.Module):
-    """Transformer Encoder for Octo, as described above.
-
-    Args:
-        config: OctoConfig, configuration class instance.
-        img_dim: int, dimension of the image feature patches.
-        n_obs_tokens_per_step: int, number of observation tokens in the input sequence per observation step.
-    """
+    """Transformer Encoder for Octo, as described above."""

     def __init__(self, config: OctoConfig, img_dim: int, n_obs_tokens_per_step: int):
+        """
+        Args:
+            config: OctoConfig, Policy configuration class instance.
+            img_dim: int, dimension of the image feature patches.
+            n_obs_tokens_per_step: int, number of observation tokens in the input sequence per observation step.
+        """
         super().__init__()

         self.config = config
@@ -547,7 +555,7 @@ class OctoTransformer(nn.Module):
         encoder_layers = TransformerEncoderLayer(
             config.embed_dim,
             config.n_heads,
-            dim_feedforward=config.d_ffn,
+            dim_feedforward=config.dim_feedforward,
             dropout=config.p_dropout,
             batch_first=True,
             norm_first=True,
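For context, a layer configured this way is typically stacked into a full encoder with `nn.TransformerEncoder`; the stacking call sits outside this hunk, so the following is a sketch using the defaults from this diff plus an assumed placeholder `embed_dim` of 384:

```python
import torch
from torch import nn

# embed_dim is not shown in this diff; 384 is a placeholder divisible by n_heads=6.
embed_dim, n_heads, dim_feedforward, n_layers = 384, 6, 1536, 12

encoder_layer = nn.TransformerEncoderLayer(
    embed_dim,
    n_heads,
    dim_feedforward=dim_feedforward,
    dropout=0.0,
    batch_first=True,
    norm_first=True,
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

x = torch.randn(1, 76, embed_dim)  # (batch, input_seq_len, embed_dim)
print(encoder(x).shape)            # torch.Size([1, 76, 384])
```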
@@ -712,13 +720,13 @@ class OctoMLPResNet(nn.Module):


 class OctoDiffusionActionHead(nn.Module):
-    """Diffusion Action Head for Octo, as described above.
-
-    Args:
-        config: OctoConfig, configuration class instance.
-    """
+    """Diffusion Action Head for Octo, as described above."""

     def __init__(self, config: OctoConfig):
+        """
+        Args:
+            config: OctoConfig, Policy configuration class instance.
+        """
         super().__init__()

         self.fourier_feature_embedder = OctoFourierFeatures(config.time_dim)
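`OctoFourierFeatures(config.time_dim)` projects the scalar denoising iteration K into a `time_dim`-dimensional feature (the `time_dim: int = 32` default above). The repo's implementation is not shown in this diff; a common fixed log-spaced variant looks like this sketch:

```python
import math
import torch
from torch import nn


class FourierFeatures(nn.Module):
    """Sketch of a sinusoidal embedding for the denoising iteration index.

    This mirrors what OctoFourierFeatures plausibly does; the real module may
    use learned frequencies instead of this fixed log-spaced bank.
    """

    def __init__(self, time_dim: int):
        super().__init__()
        half = time_dim // 2  # half sine channels, half cosine channels
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
        self.register_buffer("freqs", freqs)

    def forward(self, k: torch.Tensor) -> torch.Tensor:
        # k: (batch,) iteration indices -> (batch, time_dim) features
        angles = k[:, None].float() * self.freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


emb = FourierFeatures(time_dim=32)
print(emb(torch.tensor([0, 10, 99])).shape)  # torch.Size([3, 32])
```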
|
@ -61,7 +61,7 @@ policy:
|
|||
n_readouts: 1
|
||||
n_layers: 12
|
||||
n_heads: 6
|
||||
d_ffn: 1536
|
||||
dim_feedforward: 1536
|
||||
p_dropout: 0.
|
||||
time_dim: 32
|
||||
n_diffusion_head_layers: 3
|
||||
|