ready for review

Alexander Soare 2024-05-09 11:55:10 +01:00
parent 158627d07e
commit 4cfebf1f0a
5 changed files with 58 additions and 89 deletions

View File

@@ -67,25 +67,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         if self.config.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)
 
-    def _check_and_preprocess_batch(
-        self, batch: dict[str, Tensor], train_mode: bool = False
-    ) -> dict[str, Tensor]:
-        """Check that the keys can be handled by this policy and stack all images into one tensor.
-
-        This should be run after input normalization.
-        """
-        batch = dict(batch)  # shallow copy
-        assert "observation.state" in batch
-        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
-        assert image_keys == set(
-            self.expected_image_keys
-        ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
-        if train_mode:
-            assert "action" in batch
-            assert "action_is_pad" in batch
-        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
-        return batch
-
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -97,7 +78,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.eval()
 
         batch = self.normalize_inputs(batch)
-        batch = self._check_and_preprocess_batch(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
 
         if len(self._action_queue) == 0:
             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -113,8 +94,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
         batch = self.normalize_targets(batch)
-        batch = self._check_and_preprocess_batch(batch, train_mode=True)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
         l1_loss = (

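Taken together, the ACT hunks replace a whole validation helper with a one-liner: both `select_action` and `forward` now stack the per-camera images from `expected_image_keys` into a single `observation.images` tensor. A minimal sketch of that stacking, with hypothetical camera keys standing in for whatever `config.input_shapes` provides:

import torch

# Hypothetical camera keys; in lerobot they come from config.input_shapes.
expected_image_keys = ["observation.images.top", "observation.images.wrist"]
# Dummy (batch, channels, height, width) images, one tensor per camera.
batch = {k: torch.zeros(8, 3, 96, 96) for k in expected_image_keys}

# Stack along a new dimension at -4, i.e. just before (C, H, W).
batch["observation.images"] = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
print(batch["observation.images"].shape)  # torch.Size([8, 2, 3, 96, 96])

A plausible reason for `dim=-4` over `dim=1`: the camera axis always lands just before (C, H, W), even when the inputs carry an extra leading temporal dimension.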
View File

@@ -67,7 +67,12 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         self.diffusion = DiffusionModel(config)
 
-        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        assert len(image_keys) == 1
+        self.input_image_key = image_keys[0]
+
         self.reset()
 
     def reset(self):
         """Clear observation and action queues. Should be called on `env.reset()`"""
@@ -77,29 +82,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
             "action": deque(maxlen=self.config.n_action_steps),
         }
 
-    def _check_and_preprocess_batch_keys(
-        self, batch: dict[str, Tensor], train_mode: bool = False
-    ) -> dict[str, Tensor]:
-        """Check that the keys can be handled by this policy and standardizes the image key.
-
-        This should be run after input normalization.
-        """
-        batch = dict(batch)  # shallow copy
-        assert "observation.state" in batch
-        # There should only be one image key.
-        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
-        assert image_keys == set(
-            self.expected_image_keys
-        ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
-        if train_mode:
-            assert "action" in batch
-            assert "action_is_pad" in batch
-        image_key = next(iter(image_keys))
-        if image_key != "observation.image":
-            batch["observation.image"] = batch[image_key]
-            del batch[image_key]
-        return batch
-
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -123,13 +105,13 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
         batch = self.normalize_inputs(batch)
-        batch = self._check_and_preprocess_batch_keys(batch)
+        batch["observation.image"] = batch[self.input_image_key]
 
         self._queues = populate_queues(self._queues, batch)
 
         if len(self._queues["action"]) == 0:
             # stack n latest observations from the queue
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
             actions = self.diffusion.generate_actions(batch)
 
             # TODO(rcadene): make above methods return output dictionary?
@@ -143,7 +125,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
-        batch = self._check_and_preprocess_batch_keys(batch, train_mode=True)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
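The DiffusionPolicy hunks apply the same simplification to the single-camera case: the configured image key is resolved once in `__init__` and aliased to the canonical `observation.image` on each batch. A sketch of the aliasing, with a hypothetical `input_shapes` standing in for `config.input_shapes`:

import torch

# Hypothetical config: exactly one image key is expected by this policy.
input_shapes = {"observation.images.top": [3, 96, 96], "observation.state": [2]}

image_keys = [k for k in input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
input_image_key = image_keys[0]

batch = {
    "observation.images.top": torch.zeros(8, 3, 96, 96),
    "observation.state": torch.zeros(8, 2),
}
# Alias (not copy) the camera tensor under the canonical key the model expects.
batch["observation.image"] = batch[input_image_key]

Unlike the removed helper, the alias leaves the original key in the batch, which is presumably why the queue-stacking comprehension in `select_action` gained its `if k in self._queues` guard.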
@@ -352,9 +334,9 @@ class DiffusionRgbEncoder(nn.Module):
         # Use a dry run to get the feature map shape.
         # The dummy input should take the number of image channels from `config.input_shapes` and it should
         # use the height and width from `config.crop_shape`.
-        image_keys = {k for k in config.input_shapes if k.startswith("observation.image")}
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
         assert len(image_keys) == 1
-        image_key = next(iter(image_keys))
+        image_key = image_keys[0]
         dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
         with torch.inference_mode():
             dummy_feature_map = self.backbone(dummy_input)

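The `DiffusionRgbEncoder` hunk only trades a set comprehension for an indexable list, but it sits inside a pattern worth isolating: the feature map shape is discovered with a dry run rather than derived analytically. A self-contained approximation (the ResNet backbone construction and the shapes here are assumptions for illustration, not lerobot's exact code):

import torch
import torchvision

# Assumed stand-in for the encoder backbone: ResNet-18 minus its pooling/fc head.
backbone = torch.nn.Sequential(*list(torchvision.models.resnet18().children())[:-2])
crop_shape = (84, 84)  # hypothetical config.crop_shape
in_channels = 3  # config.input_shapes[image_key][0] in the real code

# Dry run: push one zero image through to discover the output feature map shape.
dummy_input = torch.zeros(size=(1, in_channels, *crop_shape))
with torch.inference_mode():
    dummy_feature_map = backbone(dummy_input)
print(dummy_feature_map.shape)  # torch.Size([1, 512, 3, 3]) for this backbone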
View File

@@ -96,15 +96,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
 
-        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+        # Note: This check is covered in the post-init of the config but have a sanity check just in case.
+        assert len(image_keys) == 1
+        self.input_image_key = image_keys[0]
 
-    def save(self, fp):
-        """Save state dict of TOLD model to filepath."""
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        """Load a saved state dict from filepath into current agent."""
-        self.load_state_dict(torch.load(fp))
+        self.reset()
 
     def reset(self):
         """
@@ -120,37 +117,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None
 
-    def _check_and_preprocess_batch_keys(
-        self, batch: dict[str, Tensor], train_mode: bool = False
-    ) -> dict[str, Tensor]:
-        """Check that the keys can be handled by this policy and standardizes the image key.
-
-        This should be run after input normalization.
-        """
-        batch = dict(batch)  # shallow copy
-        assert "observation.state" in batch
-        # There should only be one image key.
-        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
-        assert image_keys == set(
-            self.expected_image_keys
-        ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
-        if train_mode:
-            assert "action" in batch
-            assert "action_is_pad" in batch
-        image_key = next(iter(image_keys))
-        if image_key != "observation.image":
-            batch["observation.image"] = batch[image_key]
-            del batch[image_key]
-        return batch
-
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]):
         """Select a single action given environment observations."""
         assert "observation.image" in batch
         assert "observation.state" in batch
 
         batch = self.normalize_inputs(batch)
-        batch = self._check_and_preprocess_batch_keys(batch)
+        batch["observation.image"] = batch[self.input_image_key]
 
         self._queues = populate_queues(self._queues, batch)
@@ -329,14 +300,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         device = get_device_from_parameters(self)
 
         batch = self.normalize_inputs(batch)
-        batch = self._check_and_preprocess_batch_keys(batch, train_mode=True)
+        batch["observation.image"] = batch[self.input_image_key]
         batch = self.normalize_targets(batch)
 
         info = {}
 
         # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
-        batch_size = batch["index"].shape[0]
-
         # (b, t) -> (t, b)
         for key in batch:
             if batch[key].ndim > 1:
@@ -364,6 +332,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
+        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
         z_preds[0] = self.model.encode(current_observation)
         reward_preds = torch.empty_like(reward, device=device)

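The last two TDMPC hunks defer the `batch_size` read until after the batch has been flipped from (batch, time) to (time, batch) leading dimensions. A sketch of that flip with hypothetical shapes; the loop body is cut off in the diff above, so the `transpose(1, 0)` here is an assumption consistent with the `(b, t) -> (t, b)` comment:

import torch

batch = {
    "observation.image": torch.zeros(8, 5, 3, 84, 84),  # (b, t, c, h, w)
    "action": torch.zeros(8, 5, 4),  # (b, t, action_dim)
    "index": torch.zeros(8),  # (b,) -- 1-D, so it is left alone
}
# (b, t) -> (t, b) on every multi-dimensional tensor in the batch.
for key in batch:
    if batch[key].ndim > 1:
        batch[key] = batch[key].transpose(1, 0)

# "index" keeps shape (b,), so this still reads the batch size after the flip.
batch_size = batch["index"].shape[0]
print(batch["action"].shape, batch_size)  # torch.Size([5, 8, 4]) 8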
View File

@@ -4,6 +4,10 @@ from torch import nn
 
 def populate_queues(queues, batch):
     for key in batch:
+        # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
+        # queues have the keys they want).
+        if key not in queues:
+            continue
         if len(queues[key]) != queues[key].maxlen:
             # initialize by copying the first observation several times until the queue is full
             while len(queues[key]) != queues[key].maxlen:

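For context, here is a self-contained sketch of how the full `populate_queues` plausibly behaves with the new guard; the `else` branch is inferred from the visible comment and `while` loop, since the diff cuts off before it:

from collections import deque

def populate_queues(queues, batch):
    for key in batch:
        # The new guard: ignore keys the caller did not register as queues.
        if key not in queues:
            continue
        if len(queues[key]) != queues[key].maxlen:
            # Initialize by copying the first observation until the queue is full.
            while len(queues[key]) != queues[key].maxlen:
                queues[key].append(batch[key])
        else:
            queues[key].append(batch[key])
    return queues

queues = {"observation.image": deque(maxlen=2)}
# The extra key below no longer raises a KeyError; it is silently skipped.
batch = {"observation.image": "img_t0", "observation.images.top": "raw_t0"}
queues = populate_queues(queues, batch)
print(list(queues["observation.image"]))  # ['img_t0', 'img_t0']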
View File

@@ -49,6 +49,14 @@ def test_get_policy_and_config_classes(policy_name: str):
             "act",
             ["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
         ),
+        # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+        (
+            "aloha",
+            "diffusion",
+            ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
+        ),
+        # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+        ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
     ],
 )
 @require_env
@@ -72,6 +80,31 @@ def test_policy(env_name, policy_name, extra_overrides):
         + extra_overrides,
     )
 
+    # Additional config override logic.
+    if env_name == "aloha" and policy_name == "diffusion":
+        for keys in [
+            ("training", "delta_timestamps"),
+            ("policy", "input_shapes"),
+            ("policy", "input_normalization_modes"),
+        ]:
+            dct = dict(cfg[keys[0]][keys[1]])
+            dct["observation.images.top"] = dct["observation.image"]
+            del dct["observation.image"]
+            cfg[keys[0]][keys[1]] = dct
+        cfg.override_dataset_stats = None
+
+    # Additional config override logic.
+    if env_name == "pusht" and policy_name == "act":
+        for keys in [
+            ("policy", "input_shapes"),
+            ("policy", "input_normalization_modes"),
+        ]:
+            dct = dict(cfg[keys[0]][keys[1]])
+            dct["observation.image"] = dct["observation.images.top"]
+            del dct["observation.images.top"]
+            cfg[keys[0]][keys[1]] = dct
+        cfg.training.override_dataset_stats = None
+
     # Check that we can make the policy object.
     dataset = make_dataset(cfg)
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
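The two override blocks in `test_policy` share one pattern: pull a sub-config out as a plain dict, move a value to a renamed key, and write the dict back. A toy version with OmegaConf (Hydra configs are OmegaConf `DictConfig`s, which is why the plain-dict round trip works; the config contents here are illustrative only):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"policy": {"input_shapes": {"observation.image": [3, 96, 96], "observation.state": [2]}}}
)

# Rename "observation.image" -> "observation.images.top" via a plain-dict round trip.
dct = dict(cfg["policy"]["input_shapes"])
dct["observation.images.top"] = dct["observation.image"]
del dct["observation.image"]
cfg["policy"]["input_shapes"] = dct

print(list(cfg["policy"]["input_shapes"].keys()))  # ['observation.state', 'observation.images.top']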