backup wip
This commit is contained in:
parent
8d5c912515
commit
c643a1ba7f
|
@ -2,6 +2,7 @@ import logging
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
|
@ -70,20 +71,27 @@ class Logger:
|
|||
def save_model(self, policy, identifier):
|
||||
if self._save_model:
|
||||
self._model_dir.mkdir(parents=True, exist_ok=True)
|
||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||
# TODO(alexander-soare): This conditional branching is temporary while we add PyTorchModelHubMixin
|
||||
# to all policies.
|
||||
# to all policies. Once we're done, we should only use policy.save_pretrained.
|
||||
if hasattr(policy, "save"):
|
||||
policy.save(fp)
|
||||
model_path = self._model_dir / f"{str(identifier)}.pt"
|
||||
policy.save(model_path)
|
||||
else:
|
||||
policy.save_pretrained(fp)
|
||||
save_dir = self._model_dir / str(identifier)
|
||||
policy.save_pretrained(save_dir)
|
||||
model_path = save_dir / SAFETENSORS_SINGLE_FILE
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(fp)
|
||||
# TODO(alexander-soare). See todo above. This conditional branching is temporary and we should
|
||||
# only have the else path.
|
||||
if hasattr(policy, "save"):
|
||||
artifact.add_file(model_path)
|
||||
else:
|
||||
artifact.add_dir(save_dir)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def save_buffer(self, buffer, identifier):
|
||||
|
|
|
@ -24,7 +24,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
|||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
|
||||
|
||||
class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class ActionChunkingTransformerPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
|
@ -390,9 +390,6 @@ class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
return actions, (mu, log_sigma_x2)
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import inspect
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
def make_policy(cfg):
|
||||
def make_policy(cfg: DictConfig):
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@ import imageio
|
|||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
|
@ -290,7 +291,7 @@ def eval_policy(
|
|||
return info
|
||||
|
||||
|
||||
def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
def eval(cfg: DictConfig, out_dir=None, stats_path=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
Loading…
Reference in New Issue