Train diffusion pusht_keypoints (#307)

Co-authored-by: Remi <re.cadene@gmail.com>
Alexander Soare 2024-07-09 12:35:50 +01:00 committed by GitHub
parent a4d77b99f0
commit cc2f6e7404
4 changed files with 206 additions and 56 deletions


@@ -28,7 +28,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
     """
     # map to expected inputs for the policy
     return_observations = {}
+    if "pixels" in observations:
         if isinstance(observations["pixels"], dict):
             imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
         else:
@@ -51,8 +51,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
             return_observations[imgkey] = img

+    if "environment_state" in observations:
+        return_observations["observation.environment_state"] = torch.from_numpy(
+            observations["environment_state"]
+        ).float()
+
     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
     # requirement for "agent_pos"
     return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
     return return_observations
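
For context, a minimal standalone sketch (hypothetical helper, not the library function itself) of the branching this hunk introduces: "pixels" and "environment_state" are now both optional, while "agent_pos" stays required.

import numpy as np
import torch

def preprocess_observation_sketch(observations: dict) -> dict:
    out = {}
    if "pixels" in observations:
        pass  # image handling as in the hunk above (single camera or dict of cameras)
    if "environment_state" in observations:
        out["observation.environment_state"] = torch.from_numpy(observations["environment_state"]).float()
    out["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
    return out

# Keypoints-only example: no "pixels" key at all.
obs = {"environment_state": np.zeros(16, dtype=np.float32), "agent_pos": np.zeros(2, dtype=np.float32)}
print(sorted(preprocess_observation_sketch(obs)))  # ['observation.environment_state', 'observation.state']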


@@ -28,7 +28,10 @@ class DiffusionConfig:
     Notes on the inputs and outputs:
         - "observation.state" is required as an input key.
+        - Either:
            - At least one key starting with "observation.image" is required as an input.
+              AND/OR
+            - The key "observation.environment_state" is required as input.
         - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
           views. Right now we only support all images having the same shape.
         - "action" is required as an output key.
@@ -155,7 +158,13 @@ class DiffusionConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
         image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
+            raise ValueError("You must provide at least one image or the environment state among the inputs.")
+        if len(image_keys) > 0:
             if self.crop_shape is not None:
                 for image_key in image_keys:
                     if (
@@ -175,6 +184,7 @@ class DiffusionConfig:
                         f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
                         "expect all image shapes to match."
                     )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
             raise ValueError(
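
A standalone sketch of the validation rule added in this hunk (hypothetical helper, not the library API): at least one image key or the environment state must appear among the input shapes.

def check_diffusion_inputs(input_shapes: dict) -> None:
    image_keys = {k for k in input_shapes if k.startswith("observation.image")}
    if len(image_keys) == 0 and "observation.environment_state" not in input_shapes:
        raise ValueError("You must provide at least one image or the environment state among the inputs.")

check_diffusion_inputs({"observation.state": [2], "observation.environment_state": [16]})  # keypoints-only: ok
check_diffusion_inputs({"observation.state": [2], "observation.image": [3, 96, 96]})       # image-only: ok
# check_diffusion_inputs({"observation.state": [2]})  # would raise ValueError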


@@ -83,16 +83,20 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         self.diffusion = DiffusionModel(config)

         self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        self.use_env_state = "observation.environment_state" in config.input_shapes

         self.reset()

     def reset(self):
         """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
-            "observation.images": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
             "action": deque(maxlen=self.config.n_action_steps),
         }
+        if len(self.expected_image_keys) > 0:
+            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
+        if self.use_env_state:
+            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
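
For a keypoints-only config, reset() now builds queues without any image entry. A small sketch with assumed step counts (matching the config at the end of this diff):

from collections import deque

n_obs_steps, n_action_steps = 2, 8
expected_image_keys, use_env_state = [], True  # no camera inputs, environment state present

queues = {
    "observation.state": deque(maxlen=n_obs_steps),
    "action": deque(maxlen=n_action_steps),
}
if len(expected_image_keys) > 0:
    queues["observation.images"] = deque(maxlen=n_obs_steps)
if use_env_state:
    queues["observation.environment_state"] = deque(maxlen=n_obs_steps)
print(sorted(queues))  # ['action', 'observation.environment_state', 'observation.state']
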
@@ -117,6 +121,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
         batch = self.normalize_inputs(batch)
+        if len(self.expected_image_keys) > 0:
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)
@@ -137,6 +142,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        if len(self.expected_image_keys) > 0:
             batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
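
The stacking that both hunks now guard behind the image check, sketched with assumed camera names and shapes:

import torch

expected_image_keys = ["observation.images.cam_low", "observation.images.cam_top"]  # assumed names
batch = {k: torch.zeros(4, 2, 3, 96, 96) for k in expected_image_keys}  # (B, n_obs_steps, C, H, W)
if len(expected_image_keys) > 0:
    # The new camera dim lands at -4, giving (B, n_obs_steps, num_cameras, C, H, W).
    batch["observation.images"] = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
print(batch["observation.images"].shape)  # torch.Size([4, 2, 2, 3, 96, 96])
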
@@ -161,15 +167,20 @@ class DiffusionModel(nn.Module):
         super().__init__()
         self.config = config

-        self.rgb_encoder = DiffusionRgbEncoder(config)
+        # Build observation encoders (depending on which observations are provided).
+        global_cond_dim = config.input_shapes["observation.state"][0]
         num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
-        self.unet = DiffusionConditionalUnet1d(
-            config,
-            global_cond_dim=(
-                config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
-            )
-            * config.n_obs_steps,
-        )
+        self._use_images = False
+        self._use_env_state = False
+        if num_images > 0:
+            self._use_images = True
+            self.rgb_encoder = DiffusionRgbEncoder(config)
+            global_cond_dim += self.rgb_encoder.feature_dim * num_images
+        if "observation.environment_state" in config.input_shapes:
+            self._use_env_state = True
+            global_cond_dim += config.input_shapes["observation.environment_state"][0]
+
+        self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)

         self.noise_scheduler = _make_noise_scheduler(
             config.noise_scheduler_type,
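
The conditioning width now grows with whichever inputs are present. A worked example with assumed sizes for the keypoints setup (agent position of 2, 8 keypoints flattened to 16 dims, no cameras; the RGB feature size is only illustrative):

state_dim, env_state_dim, rgb_feature_dim = 2, 16, 64
num_images, has_env_state, n_obs_steps = 0, True, 2

global_cond_dim = state_dim
if num_images > 0:
    global_cond_dim += rgb_feature_dim * num_images
if has_env_state:
    global_cond_dim += env_state_dim
print(global_cond_dim * n_obs_steps)  # (2 + 16) * 2 = 36 conditioning features for the U-Net
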
@@ -219,24 +230,34 @@ class DiffusionModel(nn.Module):
     def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
         """Encode image features and concatenate them all together along with the state vector."""
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        global_cond_feats = [batch["observation.state"]]
         # Extract image feature (first combine batch, sequence, and camera index dims).
+        if self._use_images:
             img_features = self.rgb_encoder(
                 einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
             )
-        # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature
-        # dim (effectively concatenating the camera features).
+            # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
+            # feature dim (effectively concatenating the camera features).
             img_features = einops.rearrange(
                 img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
             )
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+            global_cond_feats.append(img_features)
+
+        if self._use_env_state:
+            global_cond_feats.append(batch["observation.environment_state"])
+
+        # Concatenate features then flatten to (B, global_cond_dim).
+        return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)

     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
         This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
+                AND/OR
+            "observation.environment_state": (B, environment_dim)
         }
         """
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
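
A shape-only sketch of the new feature list in _prepare_global_conditioning (assumed sizes, keypoints-only case): state and environment state are concatenated per observation step, then flattened across steps.

import torch

B, n_obs_steps, state_dim, env_dim = 4, 2, 2, 16
global_cond_feats = [torch.zeros(B, n_obs_steps, state_dim)]    # "observation.state"
global_cond_feats.append(torch.zeros(B, n_obs_steps, env_dim))  # "observation.environment_state"
global_cond = torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
print(global_cond.shape)  # torch.Size([4, 36]) == (B, global_cond_dim * n_obs_steps)
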
@@ -260,13 +281,18 @@ class DiffusionModel(nn.Module):
         This function expects `batch` to have (at least):
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
+                AND/OR
+            "observation.environment_state": (B, environment_dim)
             "action": (B, horizon, action_dim)
             "action_is_pad": (B, horizon)
         }
         """
         # Input validation.
-        assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
+        assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
+        assert "observation.images" in batch or "observation.environment_state" in batch
         n_obs_steps = batch["observation.state"].shape[1]
         horizon = batch["action"].shape[1]
         assert horizon == self.config.horizon
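
In words: images are no longer mandatory for the loss computation, but at least one of the two observation modalities must accompany the state and action keys. A trivial check with placeholder tensors (shapes assumed):

import torch

batch = {
    "observation.state": torch.zeros(4, 2, 2),
    "observation.environment_state": torch.zeros(4, 2, 16),
    "action": torch.zeros(4, 16, 2),
    "action_is_pad": torch.zeros(4, 16, dtype=torch.bool),
}
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
assert "observation.images" in batch or "observation.environment_state" in batch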


@@ -0,0 +1,110 @@
# @package _global_

# Defaults for training for the pusht_keypoints dataset.
#
# The keypoints are on the vertices of the rectangles that make up the PushT as documented in the PushT
# environment:
# https://github.com/huggingface/gym-pusht/blob/5e2489be9ff99ed9cd47b6c653dda3b7aa844d24/gym_pusht/envs/pusht.py#L522-L534
# For completeness, the diagram is copied here:
# 0───────────1
# │           │
# 3───4───5───2
#     │   │
#     │   │
#     │   │
#     │   │
#     7───6
#
# Note: The original work trains keypoints-only with conditioning via inpainting. Here, we encode the
# observation along with the agent position and use the encoding as global conditioning for the denoising
# U-Net.
#
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
# https://github.com/huggingface/lerobot/pull/134 for more details.
seed: 100000
dataset_repo_id: lerobot/pusht_keypoints

training:
  offline_steps: 200000
  online_steps: 0
  eval_freq: 5000
  save_freq: 5000
  log_freq: 250
  save_checkpoint: true

  batch_size: 64
  grad_clip_norm: 10
  lr: 1.0e-4
  lr_scheduler: cosine
  lr_warmup_steps: 500
  adam_betas: [0.95, 0.999]
  adam_eps: 1.0e-8
  adam_weight_decay: 1.0e-6
  online_steps_between_rollouts: 1

  delta_timestamps:
    observation.environment_state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"

  # The original implementation doesn't sample frames for the last 7 steps,
  # which avoids excessive padding and leads to improved training results.
  drop_n_last_frames: 7  # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1

eval:
  n_episodes: 50
  batch_size: 50

policy:
  name: diffusion

  # Input / output structure.
  n_obs_steps: 2
  horizon: 16
  n_action_steps: 8

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.environment_state: [16]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.environment_state: min_max
    observation.state: min_max
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Vision backbone.
  vision_backbone: resnet18
  crop_shape: [84, 84]
  crop_is_random: True
  pretrained_backbone_weights: null
  use_group_norm: True
  spatial_softmax_num_keypoints: 32
  # Unet.
  down_dims: [256, 512, 1024]
  kernel_size: 5
  n_groups: 8
  diffusion_step_embed_dim: 128
  use_film_scale_modulation: True
  # Noise scheduler.
  noise_scheduler_type: DDIM
  num_train_timesteps: 100
  beta_schedule: squaredcos_cap_v2
  beta_start: 0.0001
  beta_end: 0.02
  prediction_type: epsilon  # epsilon / sample
  clip_sample: True
  clip_sample_range: 1.0

  # Inference
  num_inference_steps: 10  # if not provided, defaults to `num_train_timesteps`

  # Loss computation
  do_mask_loss_for_padding: false
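
For reference, a sketch of how the delta_timestamps expressions above expand, assuming the PushT environment's fps of 10 (the other numbers come from this config):

fps, n_obs_steps, n_action_steps, horizon = 10, 2, 8, 16
obs_deltas = [i / fps for i in range(1 - n_obs_steps, 1)]
action_deltas = [i / fps for i in range(1 - n_obs_steps, 1 - n_obs_steps + horizon)]
print(obs_deltas)                                               # [-0.1, 0.0] -> current frame plus one step of history
print(len(action_deltas), action_deltas[0], action_deltas[-1])  # 16 -0.1 1.4
print(horizon - n_action_steps - n_obs_steps + 1)               # 7 == drop_n_last_frames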