backup wip

This commit is contained in:
Alexander Soare 2024-04-15 14:18:44 +01:00
parent 8d5c912515
commit c643a1ba7f
4 changed files with 18 additions and 12 deletions

View File

@ -2,6 +2,7 @@ import logging
import os
from pathlib import Path
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import OmegaConf
from termcolor import colored
@ -70,20 +71,27 @@ class Logger:
def save_model(self, policy, identifier):
if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt"
# TODO(alexander-soare): This conditional branching is temporary while we add PyTorchModelHubMixin
# to all policies.
# to all policies. Once we're done, we should only use policy.save_pretrained.
if hasattr(policy, "save"):
policy.save(fp)
model_path = self._model_dir / f"{str(identifier)}.pt"
policy.save(model_path)
else:
policy.save_pretrained(fp)
save_dir = self._model_dir / str(identifier)
policy.save_pretrained(save_dir)
model_path = save_dir / SAFETENSORS_SINGLE_FILE
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
type="model",
)
artifact.add_file(fp)
# TODO(alexander-soare). See todo above. This conditional branching is temporary and we should
# only have the else path.
if hasattr(policy, "save"):
artifact.add_file(model_path)
else:
artifact.add_dir(save_dir)
self._wandb.log_artifact(artifact)
def save_buffer(self, buffer, identifier):

View File

@ -24,7 +24,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
class ActionChunkingTransformerPolicy(nn.Module, PyTorchModelHubMixin):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@ -390,9 +390,6 @@ class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
return actions, (mu, log_sigma_x2)
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)

View File

@ -1,11 +1,11 @@
import inspect
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf
from lerobot.common.utils import get_safe_torch_device
def make_policy(cfg):
def make_policy(cfg: DictConfig):
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy

View File

@ -42,6 +42,7 @@ import imageio
import numpy as np
import torch
from huggingface_hub import snapshot_download
from omegaconf import DictConfig
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
@ -290,7 +291,7 @@ def eval_policy(
return info
def eval(cfg: dict, out_dir=None, stats_path=None):
def eval(cfg: DictConfig, out_dir=None, stats_path=None):
if out_dir is None:
raise NotImplementedError()