backup wip

This commit is contained in:
Alexander Soare 2024-05-08 18:43:28 +01:00
parent 4b4f922fa7
commit e80fc1d7eb
4 changed files with 65 additions and 24 deletions

View File

@ -67,6 +67,25 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
if self.config.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def _check_and_preprocess_batch(
self, batch: dict[str, Tensor], train_mode: bool = False
) -> dict[str, Tensor]:
"""Check that the keys can be handled by this policy and stack all images into one tensor.
This should be run after input normalization.
"""
batch = dict(batch) # shallow copy
assert "observation.state" in batch
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
assert image_keys == set(
self.expected_image_keys
), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
if train_mode:
assert "action" in batch
assert "action_is_pad" in batch
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
return batch
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@ -78,7 +97,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.eval()
batch = self.normalize_inputs(batch)
self._check_and_preprocess_batch(batch)
batch = self._check_and_preprocess_batch(batch)
if len(self._action_queue) == 0:
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@ -95,7 +114,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
self._check_and_preprocess_batch(batch, train_mode=True)
batch = self._check_and_preprocess_batch(batch, train_mode=True)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
@ -118,21 +137,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
return loss_dict
def _check_and_preprocess_batch(self, batch: dict[str, Tensor], train_mode: bool = False):
"""Check that the keys can be handled by this policy and stack all images into one tensor.
This should be run after input normalization.
"""
assert "observation.state" in batch
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
assert image_keys == set(
self.expected_image_keys
), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
if train_mode:
assert "action" in batch
assert "action_is_pad" in batch
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.

View File

@ -77,11 +77,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"action": deque(maxlen=self.config.n_action_steps),
}
def _check_and_preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False):
"""Check that the keys can be handled by this policy and standardize the image key.
def _check_and_preprocess_batch_keys(
self, batch: dict[str, Tensor], train_mode: bool = False
) -> dict[str, Tensor]:
"""Check that the keys can be handled by this policy and standardizes the image key.
This should be run after input normalization.
"""
batch = dict(batch) # shallow copy
assert "observation.state" in batch
# There should only be one image key.
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
@ -95,6 +98,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
if image_key != "observation.image":
batch["observation.image"] = batch[image_key]
del batch[image_key]
return batch
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@ -119,7 +123,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
batch = self.normalize_inputs(batch)
self._check_and_preprocess_batch_keys(batch)
batch = self._check_and_preprocess_batch_keys(batch)
self._queues = populate_queues(self._queues, batch)
@ -139,7 +143,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
self._check_and_preprocess_batch_keys(batch, train_mode=True)
batch = self._check_and_preprocess_batch_keys(batch, train_mode=True)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}

View File

@ -131,12 +131,18 @@ class TDMPCConfig:
def __post_init__(self):
"""Input validation (not exhaustive)."""
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) != 1:
raise ValueError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
"Only square images are handled now. Got image shape "
f"{self.input_shapes['observation.image']}."
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(

View File

@ -96,6 +96,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
@ -118,6 +120,29 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
def _check_and_preprocess_batch_keys(
self, batch: dict[str, Tensor], train_mode: bool = False
) -> dict[str, Tensor]:
"""Check that the keys can be handled by this policy and standardizes the image key.
This should be run after input normalization.
"""
batch = dict(batch) # shallow copy
assert "observation.state" in batch
# There should only be one image key.
image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
assert image_keys == set(
self.expected_image_keys
), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
if train_mode:
assert "action" in batch
assert "action_is_pad" in batch
image_key = next(iter(image_keys))
if image_key != "observation.image":
batch["observation.image"] = batch[image_key]
del batch[image_key]
return batch
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
"""Select a single action given environment observations."""
@ -125,6 +150,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
assert "observation.state" in batch
batch = self.normalize_inputs(batch)
batch = self._check_and_preprocess_batch_keys(batch)
self._queues = populate_queues(self._queues, batch)
@ -303,6 +329,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
batch = self._check_and_preprocess_batch_keys(batch, train_mode=True)
batch = self.normalize_targets(batch)
info = {}