backup wip

parent 8d5c912515
commit c643a1ba7f
@@ -2,6 +2,7 @@ import logging
 import os
 from pathlib import Path
 
+from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import OmegaConf
 from termcolor import colored
 
@@ -70,20 +71,27 @@ class Logger:
     def save_model(self, policy, identifier):
         if self._save_model:
             self._model_dir.mkdir(parents=True, exist_ok=True)
-            fp = self._model_dir / f"{str(identifier)}.pt"
             # TODO(alexander-soare): This conditional branching is temporary while we add PyTorchModelHubMixin
-            # to all policies.
+            # to all policies. Once we're done, we should only use policy.save_pretrained.
             if hasattr(policy, "save"):
-                policy.save(fp)
+                model_path = self._model_dir / f"{str(identifier)}.pt"
+                policy.save(model_path)
             else:
-                policy.save_pretrained(fp)
+                save_dir = self._model_dir / str(identifier)
+                policy.save_pretrained(save_dir)
+                model_path = save_dir / SAFETENSORS_SINGLE_FILE
             if self._wandb and not self._disable_wandb_artifact:
                 # note wandb artifact does not accept ":" in its name
                 artifact = self._wandb.Artifact(
                     self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
                     type="model",
                 )
-                artifact.add_file(fp)
+                # TODO(alexander-soare). See todo above. This conditional branching is temporary and we should
+                # only have the else path.
+                if hasattr(policy, "save"):
+                    artifact.add_file(model_path)
+                else:
+                    artifact.add_dir(save_dir)
                 self._wandb.log_artifact(artifact)
 
     def save_buffer(self, buffer, identifier):
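For context, a minimal sketch of the two on-disk layouts the updated save_model now distinguishes. ToyPolicy and the outputs/models directory are hypothetical, purely for illustration, and this assumes a recent huggingface_hub where the mixin writes safetensors weights: the legacy branch keeps a single <identifier>.pt state-dict file, while the save_pretrained branch writes a directory whose weights file is named by SAFETENSORS_SINGLE_FILE, which is also the path attached to the wandb artifact.

from pathlib import Path

import torch
from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE  # "model.safetensors"
from torch import nn


class ToyPolicy(nn.Module, PyTorchModelHubMixin):
    """Hypothetical stand-in for a policy that already uses the hub mixin."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


model_dir = Path("outputs/models")  # stand-in for Logger._model_dir
model_dir.mkdir(parents=True, exist_ok=True)
identifier = "final"
policy = ToyPolicy()

# Legacy branch: a bare state-dict file, e.g. outputs/models/final.pt
torch.save(policy.state_dict(), model_dir / f"{identifier}.pt")

# New branch: a directory written by the mixin; its weights file is what would
# be attached to the wandb artifact via add_file, or the whole directory via add_dir.
save_dir = model_dir / str(identifier)
policy.save_pretrained(save_dir)
print(save_dir / SAFETENSORS_SINGLE_FILE)  # outputs/models/final/model.safetensors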
@@ -24,7 +24,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
 from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
 
 
-class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
+class ActionChunkingTransformerPolicy(nn.Module, PyTorchModelHubMixin):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -390,9 +390,6 @@ class ActPActionChunkingTransformerPolicyolicy(nn.Module, PyTorchModelHubMixin):
         return actions, (mu, log_sigma_x2)
 
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
     def load(self, fp):
         d = torch.load(fp)
         self.load_state_dict(d)
 
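With the manual save removed, persistence for the renamed ActionChunkingTransformerPolicy is expected to go through the PyTorchModelHubMixin methods instead. Below is a minimal, hypothetical round trip: DemoPolicy is a stand-in, since constructing the real policy needs its ActionChunkingTransformerConfig, and this assumes a recent huggingface_hub that serializes simple __init__ kwargs to config.json.

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class DemoPolicy(nn.Module, PyTorchModelHubMixin):
    """Hypothetical stand-in for a policy inheriting the hub mixin."""

    def __init__(self, hidden_dim: int = 8):
        super().__init__()
        self.net = nn.Linear(hidden_dim, hidden_dim)


policy = DemoPolicy(hidden_dim=8)

# save_pretrained takes over from the removed save(fp) / torch.save(state_dict()) path.
policy.save_pretrained("outputs/demo_policy")  # hypothetical output directory

# from_pretrained rebuilds the module and loads its weights, playing the role
# that the remaining load(fp) helper still plays for the legacy .pt format.
restored = DemoPolicy.from_pretrained("outputs/demo_policy")
assert all(
    torch.equal(a, b)
    for a, b in zip(policy.state_dict().values(), restored.state_dict().values())
)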
@@ -1,11 +1,11 @@
 import inspect
 
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 from lerobot.common.utils import get_safe_torch_device
 
 
-def make_policy(cfg):
+def make_policy(cfg: DictConfig):
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
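The updated annotation documents that make_policy expects an OmegaConf DictConfig rather than a plain dict, since it dispatches on nested attribute access such as cfg.policy.name. A small illustrative sketch of that calling convention; the config content below is made up and far smaller than the real training configs.

from omegaconf import DictConfig, OmegaConf

# Illustrative config only; real configs define many more keys.
cfg: DictConfig = OmegaConf.create({"policy": {"name": "tdmpc"}, "device": "cpu"})

# Dotted attribute access is what DictConfig provides over a plain dict.
assert cfg.policy.name == "tdmpc"

# policy = make_policy(cfg)  # would dispatch on cfg.policy.name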
@@ -42,6 +42,7 @@ import imageio
 import numpy as np
 import torch
 from huggingface_hub import snapshot_download
+from omegaconf import DictConfig
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
@@ -290,7 +291,7 @@ def eval_policy(
     return info
 
 
-def eval(cfg: dict, out_dir=None, stats_path=None):
+def eval(cfg: DictConfig, out_dir=None, stats_path=None):
     if out_dir is None:
         raise NotImplementedError()
 
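The eval entry point gets the same treatment: it receives an OmegaConf DictConfig, typically loaded from a yaml file, rather than a plain dict. A hypothetical loading sketch; the path below is illustrative only.

from omegaconf import OmegaConf

# Hypothetical config path; the real configs live in the repository's config directory.
cfg = OmegaConf.load("configs/eval_example.yaml")

# The DictConfig annotation promises dotted access like cfg.policy.name,
# which a plain dict would not support.
# eval(cfg, out_dir="outputs/eval")  # would run the evaluation loop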