From abbb1d2367eaf5e13268f8112e53e43c52038206 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 22 Jul 2024 20:38:33 +0100 Subject: [PATCH] Make sure policies don't mutate the batch (#323) --- lerobot/common/policies/act/modeling_act.py | 7 ++++--- .../policies/diffusion/modeling_diffusion.py | 2 ++ lerobot/common/policies/normalize.py | 2 ++ lerobot/common/policies/tdmpc/modeling_tdmpc.py | 2 ++ lerobot/common/policies/vqbet/modeling_vqbet.py | 2 ++ tests/test_policies.py | 17 +++++++++++++++-- 6 files changed, 27 insertions(+), 5 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index c072c31e..02691701 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -101,6 +101,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch = self.normalize_inputs(batch) if len(self.expected_image_keys) > 0: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) # If we are doing temporal ensembling, do online updates where we keep track of the number of actions @@ -128,6 +129,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if len(self.expected_image_keys) > 0: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) @@ -467,10 +469,9 @@ class ACT(nn.Module): if self.use_images: all_cam_features = [] all_cam_pos_embeds = [] - images = batch["observation.images"] - for cam_index in range(images.shape[-4]): - cam_features = self.backbone(images[:, cam_index])["feature_map"] + for cam_index in range(batch["observation.images"].shape[-4]): + cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"] # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use # buffer cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index ec4039cc..0d7bab95 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -122,6 +122,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): """ batch = self.normalize_inputs(batch) if len(self.expected_image_keys) > 0: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) @@ -143,6 +144,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if len(self.expected_image_keys) > 0: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 9b055f7e..f2e1179c 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -132,6 +132,7 @@ class Normalize(nn.Module): # TODO(rcadene): should we remove torch.no_grad? @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) # shallow copy avoids mutating the input batch for key, mode in self.modes.items(): buffer = getattr(self, "buffer_" + key.replace(".", "_")) @@ -197,6 +198,7 @@ class Unnormalize(nn.Module): # TODO(rcadene): should we remove torch.no_grad? @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) # shallow copy avoids mutating the input batch for key, mode in self.modes.items(): buffer = getattr(self, "buffer_" + key.replace(".", "_")) diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index de9658e9..020e48a2 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -137,6 +137,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) @@ -316,6 +317,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): device = get_device_from_parameters(self) batch = self.normalize_inputs(batch) + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 058c177e..bc12dfa2 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -98,6 +98,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin): """ batch = self.normalize_inputs(batch) + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) @@ -123,6 +124,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) diff --git a/tests/test_policies.py b/tests/test_policies.py index 63f394e9..d9b946ab 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +from copy import deepcopy from pathlib import Path import einops @@ -161,8 +162,13 @@ def test_policy(env_name, policy_name, extra_overrides): for key in batch: batch[key] = batch[key].to(DEVICE, non_blocking=True) - # Test updating the policy + # Test updating the policy (and test that it does not mutate the batch) + batch_ = deepcopy(batch) policy.forward(batch) + assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass." + assert all( + torch.equal(batch[k], batch_[k]) for k in batch + ), "Batch values are not the same after a forward pass." # reset the policy and environment policy.reset() @@ -174,9 +180,16 @@ def test_policy(env_name, policy_name, extra_overrides): # send observation to device/gpu observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation} - # get the next action for the environment + # get the next action for the environment (also check that the observation batch is not modified) + observation_ = deepcopy(observation) with torch.inference_mode(): action = policy.select_action(observation).cpu().numpy() + assert set(observation) == set( + observation_ + ), "Observation batch keys are not the same after a forward pass." + assert all( + torch.equal(observation[k], observation_[k]) for k in observation + ), "Observation batch values are not the same after a forward pass." # Test step through policy env.step(action)