Merge remote-tracking branch 'upstream/main' into train_act

Alexander Soare 2024-05-16 15:10:09 +01:00
commit b7b001c925
9 changed files with 107 additions and 69 deletions

View File

@@ -152,10 +152,3 @@ class ACTConfig:
             raise ValueError(
                 f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
             )
-        # Check that there is only one image.
-        # TODO(alexander-soare): generalize this to multiple images.
-        if (
-            sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
-            or "observation.images.top" not in self.input_shapes
-        ):
-            raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
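
Note: the deleted block was the old single-camera guard. A minimal sketch of the condition it enforced, using hypothetical `input_shapes` values in the config's key convention:

    # A two-camera configuration that the old guard rejected.
    input_shapes = {
        "observation.images.top": [3, 480, 640],
        "observation.images.left_wrist": [3, 480, 640],
        "observation.state": [14],
    }

    # The deleted condition: exactly one image key, and it must be "observation.images.top".
    only_top = (
        sum(k.startswith("observation.images.") for k in input_shapes) == 1
        and "observation.images.top" in input_shapes
    )
    print(only_top)  # False -> the old ACTConfig.__post_init__ raised ValueError here

With the check gone, `ACTConfig` accepts any number of `observation.images.*` keys; the policy changes below handle stacking them.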

View File

@@ -62,6 +62,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         if config is None:
             config = ACTConfig()
         self.config: ACTConfig = config
+
         self.normalize_inputs = Normalize(
             config.input_shapes, config.input_normalization_modes, dataset_stats
         )
@@ -71,8 +72,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.unnormalize_outputs = Unnormalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
+
         self.model = ACT(config)
+
+        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+
+        self.reset()

     def reset(self):
         """This should be called whenever the environment is reset."""
         if self.config.temporal_ensemble_momentum is not None:
@@ -88,13 +94,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        assert "observation.images.top" in batch
-        assert "observation.state" in batch
-
         self.eval()

         batch = self.normalize_inputs(batch)
-        self._stack_images(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)

         # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
         # the first action.
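
Note: the inline `torch.stack(..., dim=-4)` replaces the `_stack_images` helper removed further down. A quick shape check with dummy tensors (camera names are hypothetical):

    import torch

    batch = {
        "observation.images.top": torch.zeros(8, 3, 96, 96),    # (B, C, H, W)
        "observation.images.wrist": torch.zeros(8, 3, 96, 96),  # (B, C, H, W)
    }
    expected_image_keys = ["observation.images.top", "observation.images.wrist"]

    # dim=-4 inserts the camera axis just before (C, H, W), giving (B, num_cameras, C, H, W).
    stacked = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
    print(stacked.shape)  # torch.Size([8, 2, 3, 96, 96])
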
@@ -129,8 +132,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
-        self._stack_images(batch)

         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
         l1_loss = (
@@ -153,21 +156,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         return loss_dict

-    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Stacks all the images in a batch and puts them in a new key: "observation.images".
-
-        This function expects `batch` to have (at least):
-        {
-            "observation.state": (B, state_dim) batch of robot states.
-            "observation.images.{name}": (B, C, H, W) tensor of images.
-        }
-        """
-        # Stack images in the order dictated by input_shapes.
-        batch["observation.images"] = torch.stack(
-            [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
-            dim=-4,
-        )
-

 class ACT(nn.Module):
     """Action Chunking Transformer: The underlying neural network for ACTPolicy.
@@ -197,10 +185,10 @@ class ACT(nn.Module):
    [ASCII architecture diagram in the class docstring, garbled in extraction and not reproduced here:
    the change relabels the Transformer encoder's generic "inputs" as separate "image emb." and
    "state emb." inputs.]
     """
@@ -342,18 +330,18 @@ class ACT(nn.Module):
             all_cam_features.append(cam_features)
             all_cam_pos_embeds.append(cam_pos_embed)
         # Concatenate camera observation feature maps and positional embeddings along the width dimension.
-        encoder_in = torch.cat(all_cam_features, axis=3)
-        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
+        encoder_in = torch.cat(all_cam_features, axis=-1)
+        cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)

         # Get positional embeddings for robot state and latent.
-        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
-        latent_embed = self.encoder_latent_input_proj(latent_sample)
+        robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])  # (B, C)
+        latent_embed = self.encoder_latent_input_proj(latent_sample)  # (B, C)

         # Stack encoder input and positional embeddings moving to (S, B, C).
         encoder_in = torch.cat(
             [
                 torch.stack([latent_embed, robot_state_embed], axis=0),
-                encoder_in.flatten(2).permute(2, 0, 1),
+                einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
             ]
         )
         pos_embed = torch.cat(
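
Note: for these 4D feature maps, `axis=3` and `axis=-1` address the same (width) dimension, so that change is purely defensive. Likewise, the `einops.rearrange` call computes exactly what the old `flatten`/`permute` chain did, just self-documenting; a quick check with a dummy feature map:

    import einops
    import torch

    encoder_in = torch.randn(2, 512, 15, 20)  # (B, C, H, W), dummy sizes

    old = encoder_in.flatten(2).permute(2, 0, 1)                # (H*W, B, C)
    new = einops.rearrange(encoder_in, "b c h w -> (h w) b c")  # (H*W, B, C)
    print(torch.equal(old, new))  # True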

View File

@@ -148,14 +148,21 @@ class DiffusionConfig:
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
         if (
-            self.crop_shape[0] > self.input_shapes["observation.image"][1]
-            or self.crop_shape[1] > self.input_shapes["observation.image"][2]
+            self.crop_shape[0] > self.input_shapes[image_key][1]
+            or self.crop_shape[1] > self.input_shapes[image_key][2]
         ):
             raise ValueError(
-                f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
-                f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
-                '`input_shapes["observation.image"]`.'
+                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                f"`input_shapes[{image_key}]`."
             )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
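
Note: `image_keys` is a set, so `next(iter(image_keys))` simply extracts its single element after the length check. A sketch of the generalized crop validation with hypothetical shapes (channels sit at index 0, so height and width are indices 1 and 2):

    input_shapes = {"observation.images.top": [3, 96, 96], "observation.state": [2]}
    crop_shape = (84, 84)

    image_keys = {k for k in input_shapes if k.startswith("observation.image")}
    assert len(image_keys) == 1
    image_key = next(iter(image_keys))

    fits = crop_shape[0] <= input_shapes[image_key][1] and crop_shape[1] <= input_shapes[image_key][2]
    print(image_key, fits)  # observation.images.top True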

View File

@@ -19,6 +19,7 @@
 TODO(alexander-soare):
   - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
+  - Make compatible with multiple image keys.
 """

 import math
@@ -83,10 +84,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         self.diffusion = DiffusionModel(config)

+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        if len(image_keys) != 1:
+            raise NotImplementedError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        self.input_image_key = image_keys[0]
+
+        self.reset()
+
     def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        """
+        """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
@@ -115,16 +124,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         "horizon" may not be the best name to describe what the variable actually means, because this period is
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]

         self._queues = populate_queues(self._queues, batch)

         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
             actions = self.diffusion.generate_actions(batch)

             # TODO(rcadene): make above methods return output dictionary?
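
Note: aliasing the dataset-specific key onto the canonical `"observation.image"` lets the rest of DiffusionPolicy stay single-image for now, and the new `if k in self._queues` guard pairs with the `populate_queues` change further down: extra batch keys (like the raw dataset key) are simply ignored when stacking. A toy illustration of the aliasing (key name hypothetical):

    import torch

    input_image_key = "observation.images.top"  # whatever key the dataset provides
    batch = {input_image_key: torch.zeros(1, 3, 96, 96), "observation.state": torch.zeros(1, 2)}

    # Alias, not copy: both keys reference the same tensor.
    batch["observation.image"] = batch[input_image_key]
    print(batch["observation.image"] is batch[input_image_key])  # True
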
@@ -138,6 +145,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
@@ -215,13 +223,12 @@ class DiffusionModel(nn.Module):
     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
-        This function expects `batch` to have (at least):
+        This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
             "observation.image": (B, n_obs_steps, C, H, W)
         }
         """
-        assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps
@@ -345,9 +352,12 @@ class DiffusionRgbEncoder(nn.Module):
         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
-        # The dummy input should take the number of image channels from `config.input_shapes` and it should use the
-        # height and width from `config.crop_shape`.
-        dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
+        # The dummy input should take the number of image channels from `config.input_shapes` and it should
+        # use the height and width from `config.crop_shape`.
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        assert len(image_keys) == 1
+        image_key = image_keys[0]
+        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
         with torch.inference_mode():
             dummy_feature_map = self.backbone(dummy_input)
         feature_map_shape = tuple(dummy_feature_map.shape[1:])
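
Note: the dry-run pattern itself is unchanged; only the channel lookup is generalized. A self-contained sketch with a ResNet-18 stand-in for the backbone (the real code takes channels from `config.input_shapes` and height/width from `config.crop_shape`):

    import torch
    import torchvision

    # Backbone = ResNet-18 minus its average-pool and fc head, as a stand-in.
    backbone = torch.nn.Sequential(*list(torchvision.models.resnet18().children())[:-2])

    dummy_input = torch.zeros(1, 3, 84, 84)  # (1, C, *crop_shape)
    with torch.inference_mode():
        feature_map_shape = tuple(backbone(dummy_input).shape[1:])
    print(feature_map_shape)  # (512, 3, 3) for an 84x84 input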

View File

@@ -147,12 +147,18 @@ class TDMPCConfig:
     def __post_init__(self):
         """Input validation (not exhaustive)."""
-        if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
+        if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
             # TODO(alexander-soare): This limitation is solely because of code in the random shift
             # augmentation. It should be able to be removed.
             raise ValueError(
-                "Only square images are handled now. Got image shape "
-                f"{self.input_shapes['observation.image']}."
+                f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
             )
         if self.n_gaussian_samples <= 0:
             raise ValueError(
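
Note: the square-image requirement is unchanged; only the key lookup is generalized. A quick check with a hypothetical non-square shape:

    input_shapes = {"observation.image": [3, 84, 96]}  # H != W

    image_key = next(iter(k for k in input_shapes if k.startswith("observation.image")))
    is_square = input_shapes[image_key][-2] == input_shapes[image_key][-1]
    print(is_square)  # False -> TDMPCConfig.__post_init__ raises ValueError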

View File

@@ -112,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )

-    def save(self, fp):
-        """Save state dict of TOLD model to filepath."""
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        """Load a saved state dict from filepath into current agent."""
-        self.load_state_dict(torch.load(fp))
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        assert len(image_keys) == 1
+        self.input_image_key = image_keys[0]
+
+        self.reset()

     def reset(self):
         """
@@ -137,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]):
         """Select a single action given environment observations."""
-        assert "observation.image" in batch
-        assert "observation.state" in batch
-
         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]

         self._queues = populate_queues(self._queues, batch)
@@ -319,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         device = get_device_from_parameters(self)

         batch = self.normalize_inputs(batch)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)

         info = {}

-        # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
-        batch_size = batch["index"].shape[0]
-
         # (b, t) -> (t, b)
         for key in batch:
             if batch[key].ndim > 1:
@@ -353,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
+        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
         z_preds[0] = self.model.encode(current_observation)
         reward_preds = torch.empty_like(reward, device=device)

View File

@@ -19,6 +19,10 @@ from torch import nn
 def populate_queues(queues, batch):
     for key in batch:
+        # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
+        # queues have the keys they want).
+        if key not in queues:
+            continue
         if len(queues[key]) != queues[key].maxlen:
             # initialize by copying the first observation several times until the queue is full
             while len(queues[key]) != queues[key].maxlen:
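
Note: the guard makes `populate_queues` tolerant of batch keys that have no queue (e.g. the policies' raw dataset-specific image keys). A self-contained sketch of the resulting behavior; the tail of the function is not shown in the hunk, so the else-branch and return here are a plausible completion:

    from collections import deque

    def populate_queues(queues, batch):
        for key in batch:
            if key not in queues:
                continue  # keys without a queue are skipped instead of raising KeyError
            if len(queues[key]) != queues[key].maxlen:
                # initialize by copying the first observation several times until the queue is full
                while len(queues[key]) != queues[key].maxlen:
                    queues[key].append(batch[key])
            else:
                queues[key].append(batch[key])
        return queues

    queues = {"observation.state": deque(maxlen=2)}
    queues = populate_queues(queues, {"observation.state": 1, "observation.images.top": 0})
    print(queues)  # {'observation.state': deque([1, 1], maxlen=2)}; the image key was ignored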

View File

@@ -23,6 +23,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from omegaconf import DictConfig

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -307,7 +308,7 @@ def add_episodes_inplace(
     sampler.num_samples = len(concat_dataset)


-def train(cfg: dict, out_dir=None, job_name=None):
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
     if job_name is None:
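
Note: the annotation change reflects what Hydra actually passes: an OmegaConf `DictConfig`, which is not a `dict` subclass. A quick demonstration:

    from omegaconf import DictConfig, OmegaConf

    cfg = OmegaConf.create({"training": {"lr": 1e-4}})
    print(isinstance(cfg, DictConfig), isinstance(cfg, dict))  # True False
    print(cfg.training.lr)  # attribute access, which a plain dict would not support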

View File

@@ -64,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str):
             "act",
             ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
         ),
+        # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+        (
+            "aloha",
+            "diffusion",
+            ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
+        ),
+        # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+        ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
     ],
 )
 @require_env
@@ -87,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides):
         + extra_overrides,
     )

+    # Additional config override logic.
+    if env_name == "aloha" and policy_name == "diffusion":
+        for keys in [
+            ("training", "delta_timestamps"),
+            ("policy", "input_shapes"),
+            ("policy", "input_normalization_modes"),
+        ]:
+            dct = dict(cfg[keys[0]][keys[1]])
+            dct["observation.images.top"] = dct["observation.image"]
+            del dct["observation.image"]
+            cfg[keys[0]][keys[1]] = dct
+        cfg.override_dataset_stats = None
+
+    # Additional config override logic.
+    if env_name == "pusht" and policy_name == "act":
+        for keys in [
+            ("policy", "input_shapes"),
+            ("policy", "input_normalization_modes"),
+        ]:
+            dct = dict(cfg[keys[0]][keys[1]])
+            dct["observation.image"] = dct["observation.images.top"]
+            del dct["observation.images.top"]
+            cfg[keys[0]][keys[1]] = dct
+        cfg.override_dataset_stats = None
+
     # Check that we can make the policy object.
     dataset = make_dataset(cfg)
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
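
Note: the tests rebuild each config subtree as a plain dict before renaming keys because renaming a key in place on a live `DictConfig` is awkward; the rebuilt dict is then assigned back. The pattern in isolation (hypothetical config values):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"policy": {"input_shapes": {"observation.image": [3, 96, 96]}}})

    # Rebuild the subtree as a plain dict, rename the key, and assign it back.
    dct = dict(cfg["policy"]["input_shapes"])
    dct["observation.images.top"] = dct.pop("observation.image")
    cfg["policy"]["input_shapes"] = dct

    print(OmegaConf.to_yaml(cfg))  # input_shapes is now keyed by observation.images.top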