diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index f52107ed..3cda4430 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -71,7 +71,12 @@ class Logger: if self._save_model: self._model_dir.mkdir(parents=True, exist_ok=True) fp = self._model_dir / f"{str(identifier)}.pt" - policy.save_pretrained(fp) + # TODO(alexander-soare): This conditional branching is temporary while we add PyTorchModelHubMixin + # to all policies. + if hasattr(policy, "save"): + policy.save(fp) + else: + policy.save_pretrained(fp) if self._wandb and not self._disable_wandb_artifact: # note wandb artifact does not accept ":" in its name artifact = self._wandb.Artifact(