Train diffusion pusht_keypoints (#307)

Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Alexander Soare 2024-07-09 12:35:50 +01:00 committed by GitHub
parent a4d77b99f0
commit cc2f6e7404
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 206 additions and 56 deletions

View File

@ -28,7 +28,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
"""
# map to expected inputs for the policy
return_observations = {}
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
@ -51,8 +51,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[imgkey] = img
if "environment_state" in observations:
return_observations["observation.environment_state"] = torch.from_numpy(
observations["environment_state"]
).float()
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return return_observations

View File

@ -28,7 +28,10 @@ class DiffusionConfig:
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- Either:
- At least one key starting with "observation.image is required as an input.
AND/OR
- The key "observation.environment_state" is required as input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views. Right now we only support all images having the same shape.
- "action" is required as an output key.
@ -155,7 +158,13 @@ class DiffusionConfig:
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if len(image_keys) > 0:
if self.crop_shape is not None:
for image_key in image_keys:
if (
@ -175,6 +184,7 @@ class DiffusionConfig:
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
"expect all image shapes to match."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(

View File

@ -83,16 +83,20 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
self.diffusion = DiffusionModel(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.use_env_state = "observation.environment_state" in config.input_shapes
self.reset()
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.images": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
if len(self.expected_image_keys) > 0:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.use_env_state:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@ -117,6 +121,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@ -137,6 +142,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
@ -161,15 +167,20 @@ class DiffusionModel(nn.Module):
super().__init__()
self.config = config
self.rgb_encoder = DiffusionRgbEncoder(config)
# Build observation encoders (depending on which observations are provided).
global_cond_dim = config.input_shapes["observation.state"][0]
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.unet = DiffusionConditionalUnet1d(
config,
global_cond_dim=(
config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
)
* config.n_obs_steps,
)
self._use_images = False
self._use_env_state = False
if num_images > 0:
self._use_images = True
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim += self.rgb_encoder.feature_dim * num_images
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
global_cond_dim += config.input_shapes["observation.environment_state"][0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
@ -219,24 +230,34 @@ class DiffusionModel(nn.Module):
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
global_cond_feats = [batch["observation.state"]]
# Extract image feature (first combine batch, sequence, and camera index dims).
if self._use_images:
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature
# dim (effectively concatenating the camera features).
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
# Concatenate state and image features then flatten to (B, global_cond_dim).
return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
global_cond_feats.append(img_features)
if self._use_env_state:
global_cond_feats.append(batch["observation.environment_state"])
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
}
"""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@ -260,13 +281,18 @@ class DiffusionModel(nn.Module):
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
assert "observation.images" in batch or "observation.environment_state" in batch
n_obs_steps = batch["observation.state"].shape[1]
horizon = batch["action"].shape[1]
assert horizon == self.config.horizon

View File

@ -0,0 +1,110 @@
# @package _global_
# Defaults for training for the pusht_keypoints dataset.
# They keypoints are on the vertices of the rectangles that make up the PushT as documented in the PushT
# environment:
# https://github.com/huggingface/gym-pusht/blob/5e2489be9ff99ed9cd47b6c653dda3b7aa844d24/gym_pusht/envs/pusht.py#L522-L534
# For completeness, the diagram is copied here:
# 0───────────1
# │ │
# 3───4───5───2
# │ │
# │ │
# │ │
# │ │
# 7───6
# Note: The original work trains keypoints-only with conditioning via inpainting. Here, we encode the
# observation along with the agent position and use the encoding as global conditioning for the denoising
# U-Net.
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
# https://github.com/huggingface/lerobot/pull/134 for more details.
seed: 100000
dataset_repo_id: lerobot/pusht_keypoints
training:
offline_steps: 200000
online_steps: 0
eval_freq: 5000
save_freq: 5000
log_freq: 250
save_checkpoint: true
batch_size: 64
grad_clip_norm: 10
lr: 1.0e-4
lr_scheduler: cosine
lr_warmup_steps: 500
adam_betas: [0.95, 0.999]
adam_eps: 1.0e-8
adam_weight_decay: 1.0e-6
online_steps_between_rollouts: 1
delta_timestamps:
observation.environment_state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
eval:
n_episodes: 50
batch_size: 50
policy:
name: diffusion
# Input / output structure.
n_obs_steps: 2
horizon: 16
n_action_steps: 8
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.environment_state: [16]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.environment_state: min_max
observation.state: min_max
output_normalization_modes:
action: min_max
# Architecture / modeling.
# Vision backbone.
vision_backbone: resnet18
crop_shape: [84, 84]
crop_is_random: True
pretrained_backbone_weights: null
use_group_norm: True
spatial_softmax_num_keypoints: 32
# Unet.
down_dims: [256, 512, 1024]
kernel_size: 5
n_groups: 8
diffusion_step_embed_dim: 128
use_film_scale_modulation: True
# Noise scheduler.
noise_scheduler_type: DDIM
num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
beta_end: 0.02
prediction_type: epsilon # epsilon / sample
clip_sample: True
clip_sample_range: 1.0
# Inference
num_inference_steps: 10 # if not provided, defaults to `num_train_timesteps`
# Loss computation
do_mask_loss_for_padding: false