Train diffusion pusht_keypoints (#307)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
parent
a4d77b99f0
commit
cc2f6e7404
|
@ -28,7 +28,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||||
"""
|
"""
|
||||||
# map to expected inputs for the policy
|
# map to expected inputs for the policy
|
||||||
return_observations = {}
|
return_observations = {}
|
||||||
|
if "pixels" in observations:
|
||||||
if isinstance(observations["pixels"], dict):
|
if isinstance(observations["pixels"], dict):
|
||||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||||
else:
|
else:
|
||||||
|
@ -51,8 +51,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||||
|
|
||||||
return_observations[imgkey] = img
|
return_observations[imgkey] = img
|
||||||
|
|
||||||
|
if "environment_state" in observations:
|
||||||
|
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||||
|
observations["environment_state"]
|
||||||
|
).float()
|
||||||
|
|
||||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||||
# requirement for "agent_pos"
|
# requirement for "agent_pos"
|
||||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||||
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
|
@ -28,7 +28,10 @@ class DiffusionConfig:
|
||||||
|
|
||||||
Notes on the inputs and outputs:
|
Notes on the inputs and outputs:
|
||||||
- "observation.state" is required as an input key.
|
- "observation.state" is required as an input key.
|
||||||
|
- Either:
|
||||||
- At least one key starting with "observation.image is required as an input.
|
- At least one key starting with "observation.image is required as an input.
|
||||||
|
AND/OR
|
||||||
|
- The key "observation.environment_state" is required as input.
|
||||||
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
||||||
views. Right now we only support all images having the same shape.
|
views. Right now we only support all images having the same shape.
|
||||||
- "action" is required as an output key.
|
- "action" is required as an output key.
|
||||||
|
@ -155,7 +158,13 @@ class DiffusionConfig:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
)
|
)
|
||||||
|
|
||||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||||
|
|
||||||
|
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
|
||||||
|
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||||
|
|
||||||
|
if len(image_keys) > 0:
|
||||||
if self.crop_shape is not None:
|
if self.crop_shape is not None:
|
||||||
for image_key in image_keys:
|
for image_key in image_keys:
|
||||||
if (
|
if (
|
||||||
|
@ -175,6 +184,7 @@ class DiffusionConfig:
|
||||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||||
"expect all image shapes to match."
|
"expect all image shapes to match."
|
||||||
)
|
)
|
||||||
|
|
||||||
supported_prediction_types = ["epsilon", "sample"]
|
supported_prediction_types = ["epsilon", "sample"]
|
||||||
if self.prediction_type not in supported_prediction_types:
|
if self.prediction_type not in supported_prediction_types:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -83,16 +83,20 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
self.diffusion = DiffusionModel(config)
|
self.diffusion = DiffusionModel(config)
|
||||||
|
|
||||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
|
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||||
self._queues = {
|
self._queues = {
|
||||||
"observation.images": deque(maxlen=self.config.n_obs_steps),
|
|
||||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||||
"action": deque(maxlen=self.config.n_action_steps),
|
"action": deque(maxlen=self.config.n_action_steps),
|
||||||
}
|
}
|
||||||
|
if len(self.expected_image_keys) > 0:
|
||||||
|
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
if self.use_env_state:
|
||||||
|
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
@ -117,6 +121,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
"""
|
"""
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
if len(self.expected_image_keys) > 0:
|
||||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||||
# Note: It's important that this happens after stacking the images into a single key.
|
# Note: It's important that this happens after stacking the images into a single key.
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
@ -137,6 +142,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
if len(self.expected_image_keys) > 0:
|
||||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
loss = self.diffusion.compute_loss(batch)
|
loss = self.diffusion.compute_loss(batch)
|
||||||
|
@ -161,15 +167,20 @@ class DiffusionModel(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
# Build observation encoders (depending on which observations are provided).
|
||||||
|
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||||
self.unet = DiffusionConditionalUnet1d(
|
self._use_images = False
|
||||||
config,
|
self._use_env_state = False
|
||||||
global_cond_dim=(
|
if num_images > 0:
|
||||||
config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
|
self._use_images = True
|
||||||
)
|
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||||
* config.n_obs_steps,
|
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||||
)
|
if "observation.environment_state" in config.input_shapes:
|
||||||
|
self._use_env_state = True
|
||||||
|
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||||
|
|
||||||
|
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||||
|
|
||||||
self.noise_scheduler = _make_noise_scheduler(
|
self.noise_scheduler = _make_noise_scheduler(
|
||||||
config.noise_scheduler_type,
|
config.noise_scheduler_type,
|
||||||
|
@ -219,24 +230,34 @@ class DiffusionModel(nn.Module):
|
||||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Encode image features and concatenate them all together along with the state vector."""
|
"""Encode image features and concatenate them all together along with the state vector."""
|
||||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||||
|
global_cond_feats = [batch["observation.state"]]
|
||||||
# Extract image feature (first combine batch, sequence, and camera index dims).
|
# Extract image feature (first combine batch, sequence, and camera index dims).
|
||||||
|
if self._use_images:
|
||||||
img_features = self.rgb_encoder(
|
img_features = self.rgb_encoder(
|
||||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||||
)
|
)
|
||||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature
|
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||||
# dim (effectively concatenating the camera features).
|
# feature dim (effectively concatenating the camera features).
|
||||||
img_features = einops.rearrange(
|
img_features = einops.rearrange(
|
||||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||||
)
|
)
|
||||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
global_cond_feats.append(img_features)
|
||||||
return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
|
||||||
|
if self._use_env_state:
|
||||||
|
global_cond_feats.append(batch["observation.environment_state"])
|
||||||
|
|
||||||
|
# Concatenate features then flatten to (B, global_cond_dim).
|
||||||
|
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||||
|
|
||||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""
|
"""
|
||||||
This function expects `batch` to have:
|
This function expects `batch` to have:
|
||||||
{
|
{
|
||||||
"observation.state": (B, n_obs_steps, state_dim)
|
"observation.state": (B, n_obs_steps, state_dim)
|
||||||
|
|
||||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||||
|
AND/OR
|
||||||
|
"observation.environment_state": (B, environment_dim)
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||||
|
@ -260,13 +281,18 @@ class DiffusionModel(nn.Module):
|
||||||
This function expects `batch` to have (at least):
|
This function expects `batch` to have (at least):
|
||||||
{
|
{
|
||||||
"observation.state": (B, n_obs_steps, state_dim)
|
"observation.state": (B, n_obs_steps, state_dim)
|
||||||
|
|
||||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||||
|
AND/OR
|
||||||
|
"observation.environment_state": (B, environment_dim)
|
||||||
|
|
||||||
"action": (B, horizon, action_dim)
|
"action": (B, horizon, action_dim)
|
||||||
"action_is_pad": (B, horizon)
|
"action_is_pad": (B, horizon)
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
# Input validation.
|
# Input validation.
|
||||||
assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
|
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
|
||||||
|
assert "observation.images" in batch or "observation.environment_state" in batch
|
||||||
n_obs_steps = batch["observation.state"].shape[1]
|
n_obs_steps = batch["observation.state"].shape[1]
|
||||||
horizon = batch["action"].shape[1]
|
horizon = batch["action"].shape[1]
|
||||||
assert horizon == self.config.horizon
|
assert horizon == self.config.horizon
|
||||||
|
|
|
@ -0,0 +1,110 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Defaults for training for the pusht_keypoints dataset.
|
||||||
|
|
||||||
|
# They keypoints are on the vertices of the rectangles that make up the PushT as documented in the PushT
|
||||||
|
# environment:
|
||||||
|
# https://github.com/huggingface/gym-pusht/blob/5e2489be9ff99ed9cd47b6c653dda3b7aa844d24/gym_pusht/envs/pusht.py#L522-L534
|
||||||
|
# For completeness, the diagram is copied here:
|
||||||
|
# 0───────────1
|
||||||
|
# │ │
|
||||||
|
# 3───4───5───2
|
||||||
|
# │ │
|
||||||
|
# │ │
|
||||||
|
# │ │
|
||||||
|
# │ │
|
||||||
|
# 7───6
|
||||||
|
|
||||||
|
|
||||||
|
# Note: The original work trains keypoints-only with conditioning via inpainting. Here, we encode the
|
||||||
|
# observation along with the agent position and use the encoding as global conditioning for the denoising
|
||||||
|
# U-Net.
|
||||||
|
|
||||||
|
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
|
||||||
|
# https://github.com/huggingface/lerobot/pull/134 for more details.
|
||||||
|
|
||||||
|
seed: 100000
|
||||||
|
dataset_repo_id: lerobot/pusht_keypoints
|
||||||
|
|
||||||
|
training:
|
||||||
|
offline_steps: 200000
|
||||||
|
online_steps: 0
|
||||||
|
eval_freq: 5000
|
||||||
|
save_freq: 5000
|
||||||
|
log_freq: 250
|
||||||
|
save_checkpoint: true
|
||||||
|
|
||||||
|
batch_size: 64
|
||||||
|
grad_clip_norm: 10
|
||||||
|
lr: 1.0e-4
|
||||||
|
lr_scheduler: cosine
|
||||||
|
lr_warmup_steps: 500
|
||||||
|
adam_betas: [0.95, 0.999]
|
||||||
|
adam_eps: 1.0e-8
|
||||||
|
adam_weight_decay: 1.0e-6
|
||||||
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
observation.environment_state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||||
|
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||||
|
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||||
|
|
||||||
|
# The original implementation doesn't sample frames for the last 7 steps,
|
||||||
|
# which avoids excessive padding and leads to improved training results.
|
||||||
|
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||||
|
|
||||||
|
eval:
|
||||||
|
n_episodes: 50
|
||||||
|
batch_size: 50
|
||||||
|
|
||||||
|
policy:
|
||||||
|
name: diffusion
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_obs_steps: 2
|
||||||
|
horizon: 16
|
||||||
|
n_action_steps: 8
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
observation.environment_state: [16]
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes:
|
||||||
|
observation.environment_state: min_max
|
||||||
|
observation.state: min_max
|
||||||
|
output_normalization_modes:
|
||||||
|
action: min_max
|
||||||
|
|
||||||
|
# Architecture / modeling.
|
||||||
|
# Vision backbone.
|
||||||
|
vision_backbone: resnet18
|
||||||
|
crop_shape: [84, 84]
|
||||||
|
crop_is_random: True
|
||||||
|
pretrained_backbone_weights: null
|
||||||
|
use_group_norm: True
|
||||||
|
spatial_softmax_num_keypoints: 32
|
||||||
|
# Unet.
|
||||||
|
down_dims: [256, 512, 1024]
|
||||||
|
kernel_size: 5
|
||||||
|
n_groups: 8
|
||||||
|
diffusion_step_embed_dim: 128
|
||||||
|
use_film_scale_modulation: True
|
||||||
|
# Noise scheduler.
|
||||||
|
noise_scheduler_type: DDIM
|
||||||
|
num_train_timesteps: 100
|
||||||
|
beta_schedule: squaredcos_cap_v2
|
||||||
|
beta_start: 0.0001
|
||||||
|
beta_end: 0.02
|
||||||
|
prediction_type: epsilon # epsilon / sample
|
||||||
|
clip_sample: True
|
||||||
|
clip_sample_range: 1.0
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
num_inference_steps: 10 # if not provided, defaults to `num_train_timesteps`
|
||||||
|
|
||||||
|
# Loss computation
|
||||||
|
do_mask_loss_for_padding: false
|
Loading…
Reference in New Issue