From c015252e205d6539e693247d132a9bec46c6d9fc Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Sat, 4 May 2024 15:33:55 +0200 Subject: [PATCH] Remove batch length restrictions in select_action (#123) --- lerobot/common/policies/act/modeling_act.py | 3 +++ lerobot/common/policies/diffusion/modeling_diffusion.py | 1 - lerobot/common/policies/tdmpc/modeling_tdmpc.py | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index f9e52e02..5ff25fea 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -71,6 +71,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ + assert "observation.images.top" in batch + assert "observation.state" in batch + self.eval() batch = self.normalize_inputs(batch) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 5b6da771..c639e2f9 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -115,7 +115,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): """ assert "observation.image" in batch assert "observation.state" in batch - assert len(batch) == 2 batch = self.normalize_inputs(batch) diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 4205b4fc..eab0f94e 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -122,7 +122,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): """Select a single action given environment observations.""" assert "observation.image" in batch assert "observation.state" in batch - assert len(batch) == 2 batch = self.normalize_inputs(batch)