ready for review
This commit is contained in:
parent
158627d07e
commit
4cfebf1f0a
|
@ -67,25 +67,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
if self.config.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def _check_and_preprocess_batch(
|
||||
self, batch: dict[str, Tensor], train_mode: bool = False
|
||||
) -> dict[str, Tensor]:
|
||||
"""Check that the keys can be handled by this policy and stack all images into one tensor.
|
||||
|
||||
This should be run after input normalization.
|
||||
"""
|
||||
batch = dict(batch) # shallow copy
|
||||
assert "observation.state" in batch
|
||||
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
|
||||
assert image_keys == set(
|
||||
self.expected_image_keys
|
||||
), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
|
||||
if train_mode:
|
||||
assert "action" in batch
|
||||
assert "action_is_pad" in batch
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
return batch
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
@ -97,7 +78,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self._check_and_preprocess_batch(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
|
@ -113,8 +94,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
batch = self._check_and_preprocess_batch(batch, train_mode=True)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
|
|
|
@ -67,7 +67,12 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
assert len(image_keys) == 1
|
||||
self.input_image_key = image_keys[0]
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
|
@ -77,29 +82,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def _check_and_preprocess_batch_keys(
|
||||
self, batch: dict[str, Tensor], train_mode: bool = False
|
||||
) -> dict[str, Tensor]:
|
||||
"""Check that the keys can be handled by this policy and standardizes the image key.
|
||||
|
||||
This should be run after input normalization.
|
||||
"""
|
||||
batch = dict(batch) # shallow copy
|
||||
assert "observation.state" in batch
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
|
||||
assert image_keys == set(
|
||||
self.expected_image_keys
|
||||
), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
|
||||
if train_mode:
|
||||
assert "action" in batch
|
||||
assert "action_is_pad" in batch
|
||||
image_key = next(iter(image_keys))
|
||||
if image_key != "observation.image":
|
||||
batch["observation.image"] = batch[image_key]
|
||||
del batch[image_key]
|
||||
return batch
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
@ -123,13 +105,13 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self._check_and_preprocess_batch_keys(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
|
@ -143,7 +125,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self._check_and_preprocess_batch_keys(batch, train_mode=True)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
@ -352,9 +334,9 @@ class DiffusionRgbEncoder(nn.Module):
|
|||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape`.
|
||||
image_keys = {k for k in config.input_shapes if k.startswith("observation.image")}
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
assert len(image_keys) == 1
|
||||
image_key = next(iter(image_keys))
|
||||
image_key = image_keys[0]
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
|
|
|
@ -96,15 +96,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
assert len(image_keys) == 1
|
||||
self.input_image_key = image_keys[0]
|
||||
|
||||
def save(self, fp):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
self.load_state_dict(torch.load(fp))
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
@ -120,37 +117,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
def _check_and_preprocess_batch_keys(
|
||||
self, batch: dict[str, Tensor], train_mode: bool = False
|
||||
) -> dict[str, Tensor]:
|
||||
"""Check that the keys can be handled by this policy and standardizes the image key.
|
||||
|
||||
This should be run after input normalization.
|
||||
"""
|
||||
batch = dict(batch) # shallow copy
|
||||
assert "observation.state" in batch
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
|
||||
assert image_keys == set(
|
||||
self.expected_image_keys
|
||||
), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
|
||||
if train_mode:
|
||||
assert "action" in batch
|
||||
assert "action_is_pad" in batch
|
||||
image_key = next(iter(image_keys))
|
||||
if image_key != "observation.image":
|
||||
batch["observation.image"] = batch[image_key]
|
||||
del batch[image_key]
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Select a single action given environment observations."""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self._check_and_preprocess_batch_keys(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
|
@ -329,14 +300,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self._check_and_preprocess_batch_keys(batch, train_mode=True)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
||||
batch_size = batch["index"].shape[0]
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
|
@ -364,6 +332,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
|
|
|
@ -4,6 +4,10 @@ from torch import nn
|
|||
|
||||
def populate_queues(queues, batch):
|
||||
for key in batch:
|
||||
# Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
|
||||
# queues have the keys they want).
|
||||
if key not in queues:
|
||||
continue
|
||||
if len(queues[key]) != queues[key].maxlen:
|
||||
# initialize by copying the first observation several times until the queue is full
|
||||
while len(queues[key]) != queues[key].maxlen:
|
||||
|
|
|
@ -49,6 +49,14 @@ def test_get_policy_and_config_classes(policy_name: str):
|
|||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
(
|
||||
"aloha",
|
||||
"diffusion",
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||
],
|
||||
)
|
||||
@require_env
|
||||
|
@ -72,6 +80,31 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
+ extra_overrides,
|
||||
)
|
||||
|
||||
# Additional config override logic.
|
||||
if env_name == "aloha" and policy_name == "diffusion":
|
||||
for keys in [
|
||||
("training", "delta_timestamps"),
|
||||
("policy", "input_shapes"),
|
||||
("policy", "input_normalization_modes"),
|
||||
]:
|
||||
dct = dict(cfg[keys[0]][keys[1]])
|
||||
dct["observation.images.top"] = dct["observation.image"]
|
||||
del dct["observation.image"]
|
||||
cfg[keys[0]][keys[1]] = dct
|
||||
cfg.override_dataset_stats = None
|
||||
|
||||
# Additional config override logic.
|
||||
if env_name == "pusht" and policy_name == "act":
|
||||
for keys in [
|
||||
("policy", "input_shapes"),
|
||||
("policy", "input_normalization_modes"),
|
||||
]:
|
||||
dct = dict(cfg[keys[0]][keys[1]])
|
||||
dct["observation.image"] = dct["observation.images.top"]
|
||||
del dct["observation.images.top"]
|
||||
cfg[keys[0]][keys[1]] = dct
|
||||
cfg.training.override_dataset_stats = None
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
|
|
Loading…
Reference in New Issue