diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index d2e8b8c9..d3467562 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -55,7 +55,7 @@ while not done: for batch in dataloader: for k in batch: batch[k] = batch[k].to(device, non_blocking=True) - info = policy(batch) + info = policy.update(batch) if step % log_freq == 0: num_samples = (step + 1) * cfg.batch_size loss = info["loss"] diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 18ea3377..567721cd 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -161,6 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module): if self.cfg.n_action_steps is not None: self._action_queue = deque([], maxlen=self.cfg.n_action_steps) + def forward(self, batch: dict[str, Tensor], **_) -> Tensor: + return self.select_action(self, batch) + @torch.no_grad def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: """Select a single action given environment observations. @@ -172,23 +175,17 @@ class ActionChunkingTransformerPolicy(nn.Module): if len(self._action_queue) == 0: # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has # shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) + self._action_queue.extend(self._select_actions(batch).transpose(0, 1)) return self._action_queue.popleft() @torch.no_grad - def select_actions(self, batch: dict[str, Tensor]) -> Tensor: + def _select_actions(self, batch: dict[str, Tensor]) -> Tensor: """Use the action chunking transformer to generate a sequence of actions.""" self.eval() self._preprocess_batch(batch, add_obs_steps_dim=True) - action = self.forward(batch, return_loss=False) - return action[: self.cfg.n_action_steps] - def __call__(self, *args, **kwargs) -> dict: - # TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method. - return self.update(*args, **kwargs) - def _preprocess_batch( self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False ) -> dict[str, Tensor]: @@ -216,9 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module): # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get # the image index dimension. - def update(self, batch, **_) -> dict: - """Run the model in train mode, compute the loss, and do an optimization step.""" - start_time = time.time() + def compute_loss(self, batch, **_) -> float: self._preprocess_batch(batch) self.train() @@ -230,6 +225,12 @@ class ActionChunkingTransformerPolicy(nn.Module): assert batch_size % num_slices == 0 loss = self.forward(batch, return_loss=True)["loss"] + return loss + + def update(self, batch, **_) -> dict: + """Run the model in train mode, compute the loss, and do an optimization step.""" + start_time = time.time() + loss = self.compute_loss(batch) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5ff6538d..8e4c1961 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy(batch, step=step) + train_info = policy.update(batch, step=step) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: @@ -313,7 +313,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy(batch, step) + train_info = policy.update(batch, step) if step % cfg.log_freq == 0: log_train_info(logger, train_info, step, cfg, online_dataset, is_offline) diff --git a/tests/test_policies.py b/tests/test_policies.py index 8ccc7c62..2547a3a2 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -4,11 +4,13 @@ import torch from lerobot.common.datasets.utils import cycle from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.protocol import Policy from lerobot.common.envs.factory import make_env from lerobot.common.datasets.factory import make_dataset from lerobot.common.utils import init_hydra_config from .utils import DEVICE, DEFAULT_CONFIG_PATH + @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", [ @@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides): """ Tests: - Making the policy object. + - Checking that the policy follows the correct protocol. - Updating the policy. - Using the policy to select actions at inference time. - Test the action can be applied to the policy @@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides): f"policy={policy_name}", f"device={DEVICE}", ] - + extra_overrides + + extra_overrides, ) # Check that we can make the policy object. policy = make_policy(cfg) + # Check that the policy follows the required protocol. + assert isinstance( + policy, Policy + ), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}." # Check that we run select_actions and get the appropriate output. dataset = make_dataset(cfg) env = make_env(cfg, num_parallel_envs=2) @@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides): batch[key] = batch[key].to(DEVICE, non_blocking=True) # Test updating the policy - policy(batch, step=0) + policy.update(batch, step=0) # reset the policy and environment policy.reset() @@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides): # Test step through policy env.step(action) -