From c362dc94e4f1040dc28e489de1f01970ff18211c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 15 Apr 2024 10:33:00 +0100 Subject: [PATCH] backup wip --- lerobot/common/logger.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index f52107ed..3cda4430 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -71,7 +71,12 @@ class Logger: if self._save_model: self._model_dir.mkdir(parents=True, exist_ok=True) fp = self._model_dir / f"{str(identifier)}.pt" - policy.save_pretrained(fp) + # TODO(alexander-soare): This conditional branching is temporary while we add PyTorchModelHubMixin + # to all policies. + if hasattr(policy, "save"): + policy.save(fp) + else: + policy.save_pretrained(fp) if self._wandb and not self._disable_wandb_artifact: # note wandb artifact does not accept ":" in its name artifact = self._wandb.Artifact(