backup wip
parent 4b4f922fa7
commit e80fc1d7eb
@@ -67,6 +67,25 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         if self.config.n_action_steps is not None:
             self._action_queue = deque([], maxlen=self.config.n_action_steps)
 
+    def _check_and_preprocess_batch(
+        self, batch: dict[str, Tensor], train_mode: bool = False
+    ) -> dict[str, Tensor]:
+        """Check that the keys can be handled by this policy and stack all images into one tensor.
+
+        This should be run after input normalization.
+        """
+        batch = dict(batch)  # shallow copy
+        assert "observation.state" in batch
+        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
+        assert image_keys == set(
+            self.expected_image_keys
+        ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
+        if train_mode:
+            assert "action" in batch
+            assert "action_is_pad" in batch
+        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+        return batch
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
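Note on the stacking step added above: `torch.stack(..., dim=-4)` groups the per-camera images into one tensor with a camera axis just in front of `(C, H, W)`. A minimal shape sketch (the camera key names below are hypothetical, not taken from this repository's datasets):

```python
import torch

# Two hypothetical camera feeds, each of shape (batch, channels, height, width).
batch = {
    "observation.images.top": torch.zeros(8, 3, 96, 96),
    "observation.images.wrist": torch.zeros(8, 3, 96, 96),
}
expected_image_keys = ["observation.images.top", "observation.images.wrist"]

# dim=-4 inserts the camera axis just before (C, H, W), giving
# (batch, num_cameras, channels, height, width).
images = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
assert images.shape == (8, 2, 3, 96, 96)
```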
@@ -78,7 +97,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         self.eval()
 
         batch = self.normalize_inputs(batch)
-        self._check_and_preprocess_batch(batch)
+        batch = self._check_and_preprocess_batch(batch)
 
         if len(self._action_queue) == 0:
             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
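The call-site change matters because `_check_and_preprocess_batch` now works on `dict(batch)`, a shallow copy: keys it adds exist only on the copy, so the caller must rebind `batch` to the return value. A small sketch of the underlying Python behaviour:

```python
def add_key(d: dict) -> dict:
    d = dict(d)  # shallow copy: a new dict holding the same value objects
    d["new_key"] = 1
    return d

original = {"a": 0}
add_key(original)             # return value dropped: original is left untouched
assert "new_key" not in original

original = add_key(original)  # rebinding picks up the added key
assert "new_key" in original
```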
@@ -95,7 +114,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
-        self._check_and_preprocess_batch(batch, train_mode=True)
+        batch = self._check_and_preprocess_batch(batch, train_mode=True)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
         l1_loss = (
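The hunk above is truncated at `l1_loss = (`. Purely as a hedged sketch, not necessarily the exact expression in this commit, an L1 reconstruction loss over the predicted action chunk that ignores padded timesteps could be written with the names visible above (`actions_hat`, `batch["action"]`, `batch["action_is_pad"]`):

```python
import torch
import torch.nn.functional as F


def masked_l1_loss(actions: torch.Tensor, actions_hat: torch.Tensor, is_pad: torch.Tensor) -> torch.Tensor:
    """L1 over (batch, chunk, action_dim) predictions, zeroing timesteps marked as padding.

    `is_pad` is a (batch, chunk) boolean mask for steps past the end of the episode.
    """
    per_element = F.l1_loss(actions, actions_hat, reduction="none")
    return (per_element * ~is_pad.unsqueeze(-1)).mean()
```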
@@ -118,21 +137,6 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
 
         return loss_dict
 
-    def _check_and_preprocess_batch(self, batch: dict[str, Tensor], train_mode: bool = False):
-        """Check that the keys can be handled by this policy and stack all images into one tensor.
-
-        This should be run after input normalization.
-        """
-        assert "observation.state" in batch
-        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
-        assert image_keys == set(
-            self.expected_image_keys
-        ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
-        if train_mode:
-            assert "action" in batch
-            assert "action_is_pad" in batch
-        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
-
 
 class ACT(nn.Module):
     """Action Chunking Transformer: The underlying neural network for ACTPolicy.
@@ -77,11 +77,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
             "action": deque(maxlen=self.config.n_action_steps),
         }
 
-    def _check_and_preprocess_batch_keys(self, batch: dict[str, Tensor], train_mode: bool = False):
-        """Check that the keys can be handled by this policy and standardize the image key.
+    def _check_and_preprocess_batch_keys(
+        self, batch: dict[str, Tensor], train_mode: bool = False
+    ) -> dict[str, Tensor]:
+        """Check that the keys can be handled by this policy and standardizes the image key.
 
         This should be run after input normalization.
         """
+        batch = dict(batch)  # shallow copy
         assert "observation.state" in batch
         # There should only be one image key.
         image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
@@ -95,6 +98,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         if image_key != "observation.image":
             batch["observation.image"] = batch[image_key]
             del batch[image_key]
+        return batch
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
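The single-image policies rename whatever `observation.image*` key is present to the canonical `observation.image` before the model sees the batch. A standalone sketch of that renaming (the helper name and the example key are illustrative, not part of the repository):

```python
def standardize_image_key(batch: dict) -> dict:
    batch = dict(batch)  # shallow copy, as in the policy methods above
    image_keys = {
        k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")
    }
    assert len(image_keys) == 1, f"Expected exactly one image key, got {image_keys}."
    (image_key,) = image_keys
    if image_key != "observation.image":
        batch["observation.image"] = batch.pop(image_key)
    return batch


out = standardize_image_key({"observation.state": 0, "observation.images.top": 1})
assert "observation.image" in out and "observation.images.top" not in out
```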
@@ -119,7 +123,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
         actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
         """
         batch = self.normalize_inputs(batch)
-        self._check_and_preprocess_batch_keys(batch)
+        batch = self._check_and_preprocess_batch_keys(batch)
 
         self._queues = populate_queues(self._queues, batch)
 
@@ -139,7 +143,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
-        self._check_and_preprocess_batch_keys(batch, train_mode=True)
+        batch = self._check_and_preprocess_batch_keys(batch, train_mode=True)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}
@@ -131,12 +131,18 @@ class TDMPCConfig:
 
     def __post_init__(self):
         """Input validation (not exhaustive)."""
-        if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
+        # There should only be one image key.
+        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+        if len(image_keys) != 1:
+            raise ValueError(
+                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+            )
+        image_key = next(iter(image_keys))
+        if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
             # TODO(alexander-soare): This limitation is solely because of code in the random shift
             # augmentation. It should be able to be removed.
             raise ValueError(
-                "Only square images are handled now. Got image shape "
-                f"{self.input_shapes['observation.image']}."
+                f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
             )
         if self.n_gaussian_samples <= 0:
             raise ValueError(
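The config validation now locates the single `observation.image*` key instead of hard-coding `observation.image`, then applies the square-image check to that key. A standalone sketch of the same logic over a plain `input_shapes` dict (the function name is illustrative):

```python
def validate_image_shape(input_shapes: dict[str, tuple[int, ...]]) -> None:
    image_keys = {k for k in input_shapes if k.startswith("observation.image")}
    if len(image_keys) != 1:
        raise ValueError(f"Only one image input is handled for now. Got image keys {image_keys}.")
    image_key = next(iter(image_keys))
    height, width = input_shapes[image_key][-2:]
    if height != width:
        raise ValueError(f"Only square images are handled now. Got image shape {input_shapes[image_key]}.")


validate_image_shape({"observation.image": (3, 84, 84), "observation.state": (4,)})  # passes
```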
@@ -96,6 +96,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
 
+        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+
     def save(self, fp):
         """Save state dict of TOLD model to filepath."""
         torch.save(self.state_dict(), fp)
@@ -118,6 +120,29 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None
 
+    def _check_and_preprocess_batch_keys(
+        self, batch: dict[str, Tensor], train_mode: bool = False
+    ) -> dict[str, Tensor]:
+        """Check that the keys can be handled by this policy and standardizes the image key.
+
+        This should be run after input normalization.
+        """
+        batch = dict(batch)  # shallow copy
+        assert "observation.state" in batch
+        # There should only be one image key.
+        image_keys = {k for k in batch if k.startswith("observation.image") and not k.endswith("_is_pad")}
+        assert image_keys == set(
+            self.expected_image_keys
+        ), f"Expected image keys: {self.expected_image_keys}. Got {image_keys}."
+        if train_mode:
+            assert "action" in batch
+            assert "action_is_pad" in batch
+        image_key = next(iter(image_keys))
+        if image_key != "observation.image":
+            batch["observation.image"] = batch[image_key]
+            del batch[image_key]
+        return batch
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]):
         """Select a single action given environment observations."""
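The `train_mode` flag only changes which keys are required: at rollout time (`select_action`) no action labels are expected, while training (`forward`/`update`) asserts `action` and `action_is_pad`. An illustration with a hypothetical checker sharing the same contract:

```python
def check_keys(batch: dict, train_mode: bool = False) -> dict:
    # Same contract as the policies' helpers: state is always required,
    # action labels only when training.
    assert "observation.state" in batch
    if train_mode:
        assert "action" in batch
        assert "action_is_pad" in batch
    return dict(batch)


rollout_batch = {"observation.state": 0, "observation.image": 0}
check_keys(rollout_batch)                 # fine: no labels needed at rollout
train_batch = {**rollout_batch, "action": 0, "action_is_pad": 0}
check_keys(train_batch, train_mode=True)  # fine: labels present for training
```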
@@ -125,6 +150,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         assert "observation.state" in batch
 
         batch = self.normalize_inputs(batch)
+        batch = self._check_and_preprocess_batch_keys(batch)
 
         self._queues = populate_queues(self._queues, batch)
 
@@ -303,6 +329,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         device = get_device_from_parameters(self)
 
         batch = self.normalize_inputs(batch)
+        batch = self._check_and_preprocess_batch_keys(batch, train_mode=True)
         batch = self.normalize_targets(batch)
 
         info = {}