From c643a1ba7f6a7156d4463d572cff7c9f04747b76 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Mon, 15 Apr 2024 14:18:44 +0100
Subject: [PATCH] backup wip

---
 lerobot/common/logger.py                    | 18 +++++++++++++-----
 lerobot/common/policies/act/modeling_act.py |  5 +----
 lerobot/common/policies/factory.py          |  4 ++--
 lerobot/scripts/eval.py                     |  3 ++-
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 3cda4430..059b6731 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -2,6 +2,7 @@ import logging
 import os
 from pathlib import Path

+from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import OmegaConf
 from termcolor import colored

@@ -70,20 +71,27 @@ class Logger:
     def save_model(self, policy, identifier):
         if self._save_model:
             self._model_dir.mkdir(parents=True, exist_ok=True)
-            fp = self._model_dir / f"{str(identifier)}.pt"
             # TODO(alexander-soare): This conditional branching is temporary while we add PyTorchModelHubMixin
-            # to all policies.
+            # to all policies. Once we're done, we should only use policy.save_pretrained.
             if hasattr(policy, "save"):
-                policy.save(fp)
+                model_path = self._model_dir / f"{str(identifier)}.pt"
+                policy.save(model_path)
             else:
-                policy.save_pretrained(fp)
+                save_dir = self._model_dir / str(identifier)
+                policy.save_pretrained(save_dir)
+                model_path = save_dir / SAFETENSORS_SINGLE_FILE
             if self._wandb and not self._disable_wandb_artifact:
                 # note wandb artifact does not accept ":" in its name
                 artifact = self._wandb.Artifact(
                     self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
                     type="model",
                 )
-                artifact.add_file(fp)
+                # TODO(alexander-soare). See todo above. This conditional branching is temporary and we should
+                # only have the else path.
+                if hasattr(policy, "save"):
+                    artifact.add_file(model_path)
+                else:
+                    artifact.add_dir(save_dir)
                 self._wandb.log_artifact(artifact)

     def save_buffer(self, buffer, identifier):

diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index a8513bce..85b0ce3f 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -24,7 +24,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig


-class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
+class ActionChunkingTransformerPolicy(nn.Module, PyTorchModelHubMixin):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -390,9 +390,6 @@ class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):

         return actions, (mu, log_sigma_x2)

-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
     def load(self, fp):
         d = torch.load(fp)
         self.load_state_dict(d)

diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 80ae27da..e021adf6 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,11 +1,11 @@
 import inspect

-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf

 from lerobot.common.utils import get_safe_torch_device


-def make_policy(cfg):
+def make_policy(cfg: DictConfig):
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy

diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 2b8906d7..9e1b8637 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -42,6 +42,7 @@ import imageio
 import numpy as np
 import torch
 from huggingface_hub import snapshot_download
+from omegaconf import DictConfig

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
@@ -290,7 +291,7 @@ def eval_policy(
     return info


-def eval(cfg: dict, out_dir=None, stats_path=None):
+def eval(cfg: DictConfig, out_dir=None, stats_path=None):
     if out_dir is None:
         raise NotImplementedError()
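Note (reviewer context, not part of the patch): the sketch below illustrates the save/load round trip that the new else-branch in Logger.save_model relies on. ToyPolicy and the output path are hypothetical stand-ins; save_pretrained/from_pretrained and SAFETENSORS_SINGLE_FILE come from huggingface_hub, as used in the patched code.

# Minimal sketch, assuming a no-argument policy constructor. ToyPolicy and
# the save directory below are hypothetical, for illustration only.
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE


class ToyPolicy(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)


policy = ToyPolicy()

# Mirrors the else-branch of Logger.save_model: save_pretrained writes a
# directory containing model.safetensors (SAFETENSORS_SINGLE_FILE), which is
# the file `model_path` points at in the patch.
save_dir = "outputs/models/0001"  # hypothetical path
policy.save_pretrained(save_dir)
model_path = f"{save_dir}/{SAFETENSORS_SINGLE_FILE}"

# Round trip: rebuild the module and reload the saved weights from the
# directory, which is what consumers of the wandb artifact would do.
restored = ToyPolicy.from_pretrained(save_dir)

This is also why the artifact logging switches from add_file to add_dir for mixin-based policies: save_pretrained produces a directory (weights plus any config), not a single .pt file.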