backup wip

Alexander Soare 2024-04-15 14:18:44 +01:00
parent 8d5c912515
commit c643a1ba7f
4 changed files with 18 additions and 12 deletions

View File

@@ -2,6 +2,7 @@ import logging
 import os
 from pathlib import Path

+from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import OmegaConf
 from termcolor import colored
@@ -70,20 +71,27 @@ class Logger:
     def save_model(self, policy, identifier):
         if self._save_model:
             self._model_dir.mkdir(parents=True, exist_ok=True)
-            fp = self._model_dir / f"{str(identifier)}.pt"
             # TODO(alexander-soare): This conditional branching is temporary while we add PyTorchModelHubMixin
-            # to all policies.
+            # to all policies. Once we're done, we should only use policy.save_pretrained.
             if hasattr(policy, "save"):
-                policy.save(fp)
+                model_path = self._model_dir / f"{str(identifier)}.pt"
+                policy.save(model_path)
             else:
-                policy.save_pretrained(fp)
+                save_dir = self._model_dir / str(identifier)
+                policy.save_pretrained(save_dir)
+                model_path = save_dir / SAFETENSORS_SINGLE_FILE
             if self._wandb and not self._disable_wandb_artifact:
                 # note wandb artifact does not accept ":" in its name
                 artifact = self._wandb.Artifact(
                     self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
                     type="model",
                 )
-                artifact.add_file(fp)
+                # TODO(alexander-soare). See todo above. This conditional branching is temporary and we should
+                # only have the else path.
+                if hasattr(policy, "save"):
+                    artifact.add_file(model_path)
+                else:
+                    artifact.add_dir(save_dir)
                 self._wandb.log_artifact(artifact)

     def save_buffer(self, buffer, identifier):
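For context on the save path above: save_pretrained comes from huggingface_hub's PyTorchModelHubMixin and writes the weights into a directory as model.safetensors, which is exactly the filename that SAFETENSORS_SINGLE_FILE holds; that is why the wandb branch switches from add_file to add_dir. A minimal roundtrip sketch (TinyPolicy is illustrative, not a lerobot class; assumes a huggingface_hub version that serializes the mixin to safetensors):

    # Sketch only; TinyPolicy is a stand-in, not part of this commit.
    from pathlib import Path

    from torch import nn
    from huggingface_hub import PyTorchModelHubMixin
    from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE  # "model.safetensors"

    class TinyPolicy(nn.Module, PyTorchModelHubMixin):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 2)

    policy = TinyPolicy()
    policy.save_pretrained("checkpoints/0001")  # writes checkpoints/0001/model.safetensors
    assert (Path("checkpoints/0001") / SAFETENSORS_SINGLE_FILE).exists()
    restored = TinyPolicy.from_pretrained("checkpoints/0001")  # rebuilds the module, loads weights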

View File

@@ -24,7 +24,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig


-class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
+class ActionChunkingTransformerPolicy(nn.Module, PyTorchModelHubMixin):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -390,9 +390,6 @@ class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
         return actions, (mu, log_sigma_x2)

-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
     def load(self, fp):
         d = torch.load(fp)
         self.load_state_dict(d)
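With save removed, checkpointing now flows through save_pretrained (see the logger change above), while load stays behind for the moment. A hypothetical comparison of the two load paths (paths and usage are illustrative, not confirmed by this diff):

    # Legacy path, still present after this commit:
    policy.load("outputs/train/checkpoints/best.pt")

    # Mixin path (hypothetical usage; the directory layout matches what
    # Logger.save_model now writes via save_pretrained):
    policy = ActionChunkingTransformerPolicy.from_pretrained("outputs/train/checkpoints/best")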

View File

@@ -1,11 +1,11 @@
 import inspect

-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf

 from lerobot.common.utils import get_safe_torch_device


-def make_policy(cfg):
+def make_policy(cfg: DictConfig):
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
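The new annotation makes the contract explicit: make_policy expects an OmegaConf DictConfig (attribute access like cfg.policy.name), not a plain dict. A minimal conforming call (any config fields beyond policy.name are assumptions):

    from omegaconf import OmegaConf

    # OmegaConf.create on a nested dict returns a DictConfig.
    cfg = OmegaConf.create({"policy": {"name": "tdmpc"}})
    policy = make_policy(cfg)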

View File

@@ -42,6 +42,7 @@ import imageio
 import numpy as np
 import torch
 from huggingface_hub import snapshot_download
+from omegaconf import DictConfig

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
@@ -290,7 +291,7 @@ def eval_policy(
     return info


-def eval(cfg: dict, out_dir=None, stats_path=None):
+def eval(cfg: DictConfig, out_dir=None, stats_path=None):
     if out_dir is None:
         raise NotImplementedError()