more verbose var naming + update docstrings

Akshay Kashyap 2024-05-29 11:54:53 -07:00
parent ec0ae30fb8
commit 9bc8f89a3a
3 changed files with 72 additions and 64 deletions


@@ -61,7 +61,7 @@ class OctoConfig:
n_readouts_per_step: Number of learned readout heads to use per observation step.
n_layers: Number of transformer encoder layers.
n_heads: Number of heads in the transformer.
d_ffn: Dimension of the feedforward network in the transformer.
dim_feedforward: Dimension of the feedforward network in the transformer.
p_dropout: Dropout rate in the attention and feedforward networks.
time_dim: Dimension of the denoising iteration index feature projection.
n_diffusion_head_layers: Number of layers in the action diffusion head.
@@ -124,7 +124,7 @@ class OctoConfig:
n_readouts_per_step: int = 1
n_layers: int = 12
n_heads: int = 6
d_ffn: int = 1536
dim_feedforward: int = 1536
p_dropout: float = 0.0
time_dim: int = 32
n_diffusion_head_layers: int = 3
@@ -140,7 +140,7 @@ class OctoConfig:
clip_sample_range: float = 1.0
# Inference
num_inference_steps: int | None = None
num_inference_steps: int | None = 20
# Loss computation
do_mask_loss_for_padding: bool = False
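For downstream users the rename is a breaking change to the config surface: code that passed `d_ffn` must now pass `dim_feedforward`, and inference now defaults to 20 denoising steps rather than requiring an explicit value. A minimal sketch of constructing the config after this commit (the import path is a guess, not shown in this diff):

```python
# Sketch only: the module path is an assumption; adjust to wherever
# OctoConfig lives in this repository.
from configuration_octo import OctoConfig

config = OctoConfig(
    dim_feedforward=1536,    # formerly `d_ffn`
    num_inference_steps=20,  # now the default instead of None
)
```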


@@ -67,9 +67,11 @@ class OctoPolicy(nn.Module, PyTorchModelHubMixin):
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = OctoConfig()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
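The new fallback means a default-configured policy can now be built with no arguments. A sketch of both call patterns (assuming `config` is the first parameter of `OctoPolicy.__init__`, as the diff context suggests):

```python
# Sketch: equivalent after this commit.
policy = OctoPolicy()                     # falls back to OctoConfig() internally
policy = OctoPolicy(config=OctoConfig())  # explicit, same result
```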
@@ -166,54 +168,60 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler
class OctoModel(nn.Module):
"""An overview of this minimal Octo model implementation:
There are two main components:
1) OctoTransformer, which processes an input sequence of state and image tokens (collectively called
"observation tokens") and readout tokens as follows:
- Normalized state inputs of shape (`n_obs_steps`, `state_dim`) are projected to (`n_obs_steps`, `embed_dim`).
- Feature maps of the images generated by a vision encoder are flattened along the spatial dimension and
projected to (`n_obs_steps`, `n_img_features`, `embed_dim`).
- The above are concatenated to form a sequence of observation tokens of shape (`n_obs_steps`, `n_obs_tokens_per_step`, `embed_dim`).
- Additionally, learned "readout" tokens of shape (`n_obs_steps`, `n_readouts_per_step`, `embed_dim`) are appended
after the tokens of each observation step to form the final input sequence for the transformer encoder.
A causal mask (see make_causal_mask for an example visualization) is used to prevent:
a) observation and readout tokens from attending to any future tokens,
b) observation tokens from attending to any future or prior readout tokens, and
c) readout tokens from attending to prior readout tokens.
2) OctoDiffusionActionHead, which predicts the noise to remove from a noisy trajectory, conditioned on the mean
of all readout embeddings and a projection of the denoising iteration K.
An example, with `n_obs_steps`=2, `n_readouts_per_step`=1, a 7-DoF state (joint angles, gripper, etc.)
and a 96x96 wrist image (which would result in 256x6x6 feature maps from resnet18):
`n_obs_tokens_per_step` = 6*6 + 1 = 37
`input_seq_len` = (`n_obs_tokens_per_step` + `n_readouts_per_step`) * `n_obs_steps` = 38*2 = 76
---------------------------------------------------------------------------
Token Index | 0 | 1 | 2 | 3 | 4 | ... | 36 | 37 | 38 | 39 | ... | 74 | 75 |
---------------------------------------------------------------------------
Obs Timestep| 0 | 0 | 0 | 0 | 0 | ... | 1 | 1 | 1 | 1 | ... | 1 | 1 |
---------------------------------------------------------------------------
Token Type |obs|obs|obs|obs|obs| ... |obs |rout|obs |obs | ... |obs |rout|
------------|---|---|---|---|---|-----|----|----|----|----|-----|----|----|
| |
V V
<r_embed_1> <r_embed_2>
| |
--------> (Mean) <--------
|
V
<noisy_sample>, <K_proj> --> (Action Diffusion Head) --> <noise_pred>
Note that this implementation does not (yet) include certain features from the original Octo implementation:
1) Language and Goal Conditioning: The original Octo supports conditioning on language and goal images, which
would be tokenized and prepended to the input sequence.
2) Multiple trajectory generation: The original Octo generates a trajectory for each readout token (i.e.,
a trajectory starting at each observation step). This implementation only generates a single trajectory using
the mean of all readout tokens.
3) MAP over multiple readout tokens: The original Octo has an option to use Multihead Attention Pooling over
multiple readout tokens for each observation step. This implementation supports multiple readout tokens but
uses simple mean pooling over them.
"""
def __init__(self, config: OctoConfig):
"""An overview of this minimal Octo model implementation:
There are two main components:
1) OctoTransformer, which processes an input sequence of state and image tokens (collectively called
"observation tokens") and readout tokens as follows:
- Normalized state inputs of shape (`n_obs_steps`, `state_dim`) are projected to (`n_obs_steps`, `embed_dim`).
- Feature maps of the images generated by a vision encoder are flattened along the spatial dimension and
projected to (`n_obs_steps`, `n_img_features`, `embed_dim`).
- The above are concatenated to form a sequence of observation tokens of shape (`n_obs_steps`, `n_obs_tokens_per_step`, `embed_dim`).
- Additionally, learned "readout" tokens of shape (`n_obs_steps`, `n_readouts_per_step`, `embed_dim`) are appended
after the tokens of each observation step to form the final input sequence for the transformer encoder.
A causal mask (see make_causal_mask for an example visualization) is used to prevent:
a) observation and readout tokens from attending to any future tokens,
b) observation tokens from attending to any readout tokens, and
c) readout tokens from attending to prior readout tokens.
2) OctoDiffusionActionHead, which predicts the noise to remove from a noisy trajectory, conditioned on the mean
of all readout embeddings and a projection of the denoising iteration K.
An example, with `n_obs_steps`=2, `n_readouts_per_step`=1, a 7-DoF state (joint angles, gripper, etc.)
and a 96x96 wrist image (which would result in 256x6x6 feature maps using resnet18):
`n_obs_tokens_per_step` = 6*6 + 1 + 1 = 38
---------------------------------------------------------------------------
Token Index | 0 | 1 | 2 | 3 | 4 | ... | 36 | 37 | 38 | 39 | ... | 74 | 75 |
---------------------------------------------------------------------------
Obs Timestep| 0 | 0 | 0 | 0 | 0 | ... | 1 | 1 | 1 | 1 | ... | 1 | 1 |
---------------------------------------------------------------------------
Token Type |obs|obs|obs|obs|obs| ... |obs |rout|obs |obs | ... |obs |rout|
------------|---|---|---|---|---|-----|----|----|----|----|-----|----|----|
| |
V V
<r_embed_1> <r_embed_2>
| |
--------> (Mean) <--------
|
V
<noisy_sample>, <K_proj> --> (Action Diffusion Head) --> <noise_pred>
Note that this implementation does not (yet) include certain features from the original Octo implementation:
1) Language and Goal Conditioning: The original Octo supports conditioning on language and goal images, which
would be tokenized and prepended to the input sequence.
2) Multiple trajectory generation: The original Octo generates a trajectory for each readout token (i.e.,
a trajectory starting at each observation step). This implementation only generates a single trajectory using
the mean of all readout tokens.
3) MAP over multiple readout tokens: The original Octo has an option to use Multihead Attention Pooling over
multiple readout tokens for each observation step. This implementation supports multiple readout tokens but
uses simple mean pooling over them.
"""
Args:
config: OctoConfig, Policy configuration class instance.
"""
super().__init__()
self.config = config
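The three masking rules described in the docstring translate mechanically into a block-structured attention mask. Below is a minimal sketch at observation-step granularity, using the convention of PyTorch's `src_mask` (`True` = position may not be attended to); the repository's actual `make_causal_mask` may differ in details:

```python
import torch

def make_causal_mask(n_obs_tokens_per_step: int, n_obs_steps: int, n_readouts_per_step: int) -> torch.Tensor:
    tokens_per_step = n_obs_tokens_per_step + n_readouts_per_step
    seq_len = tokens_per_step * n_obs_steps
    idx = torch.arange(seq_len)
    step = idx // tokens_per_step  # observation step each token belongs to
    is_readout = (idx % tokens_per_step) >= n_obs_tokens_per_step

    q_step, k_step = step[:, None], step[None, :]
    q_ro, k_ro = is_readout[:, None], is_readout[None, :]

    allowed = k_step <= q_step                     # a) no attending to future steps
    allowed &= ~(~q_ro & k_ro)                     # b) obs tokens never attend to readouts
    allowed &= ~(q_ro & k_ro & (k_step < q_step))  # c) readouts ignore earlier readouts
    return ~allowed                                # True = blocked
```

With the docstring's example numbers, `make_causal_mask(37, 2, 1)` produces the 76x76 mask over the token layout shown in the table above.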
@@ -507,15 +515,15 @@ def make_causal_mask(n_obs_tokens_per_step, n_obs_steps, n_readouts_per_step):
class OctoTransformer(nn.Module):
"""Transformer Encoder for Octo, as described above.
Args:
config: OctoConfig, configuration class instance.
img_dim: int, dimension of the image feature patches.
n_obs_tokens_per_step: int, number of observation tokens in the input sequence per observation step.
"""
"""Transformer Encoder for Octo, as described above."""
def __init__(self, config: OctoConfig, img_dim: int, n_obs_tokens_per_step: int):
"""
Args:
config: OctoConfig, Policy configuration class instance.
img_dim: int, dimension of the image feature patches.
n_obs_tokens_per_step: int, number of observation tokens in the input sequence per observation step.
"""
super().__init__()
self.config = config
@@ -547,7 +555,7 @@ class OctoTransformer(nn.Module):
encoder_layers = TransformerEncoderLayer(
config.embed_dim,
config.n_heads,
dim_feedforward=config.d_ffn,
dim_feedforward=config.dim_feedforward,
dropout=config.p_dropout,
batch_first=True,
norm_first=True,
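For reference, these config fields map one-to-one onto the arguments of PyTorch's stock encoder layer; a sketch of how the stack is assembled (`embed_dim` is not shown in this hunk, so 384 below is a placeholder):

```python
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(
    d_model=384,           # config.embed_dim -- placeholder, not from this diff
    nhead=6,               # config.n_heads
    dim_feedforward=1536,  # config.dim_feedforward (formerly d_ffn)
    dropout=0.0,           # config.p_dropout
    batch_first=True,
    norm_first=True,       # pre-LayerNorm blocks
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)  # config.n_layers
```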
@@ -712,13 +720,13 @@ class OctoMLPResNet(nn.Module):
class OctoDiffusionActionHead(nn.Module):
"""Diffusion Action Head for Octo, as described above.
Args:
config: OctoConfig, configuration class instance.
"""
"""Diffusion Action Head for Octo, as described above."""
def __init__(self, config: OctoConfig):
"""
Args:
config: OctoConfig, Policy configuration class instance.
"""
super().__init__()
self.fourier_feature_embedder = OctoFourierFeatures(config.time_dim)
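`OctoFourierFeatures(config.time_dim)` embeds the scalar denoising step K before it conditions the head (the `<K_proj>` in the diagram above). Its internals are not part of this diff; the sketch below shows one common learnable Fourier-feature construction, purely as an illustration:

```python
import math
import torch
import torch.nn as nn

class FourierFeatures(nn.Module):
    """Illustrative stand-in for OctoFourierFeatures; the real module may differ."""

    def __init__(self, time_dim: int):
        super().__init__()
        # One learnable frequency per sin/cos feature pair.
        self.freqs = nn.Parameter(torch.randn(time_dim // 2))

    def forward(self, k: torch.Tensor) -> torch.Tensor:
        # k: (batch,) denoising step indices -> (batch, time_dim) embedding.
        angles = 2 * math.pi * k.float()[:, None] * self.freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

emb = FourierFeatures(32)(torch.tensor([5, 17]))  # shape (2, 32)
```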


@@ -61,7 +61,7 @@ policy:
n_readouts: 1
n_layers: 12
n_heads: 6
d_ffn: 1536
dim_feedforward: 1536
p_dropout: 0.
time_dim: 32
n_diffusion_head_layers: 3