From 77b61e364eda2865e5c9e031cd8cd70612ff73f3 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Tue, 21 May 2024 09:01:34 +0100
Subject: [PATCH] backup wip

---
 Makefile                              |  8 +++---
 lerobot/common/logger.py              | 40 ++++++++++++---------------
 lerobot/configs/default.yaml          |  4 +--
 lerobot/configs/policy/act.yaml       |  2 +-
 lerobot/configs/policy/diffusion.yaml |  2 +-
 lerobot/scripts/train.py              |  4 +--
 6 files changed, 28 insertions(+), 32 deletions(-)

diff --git a/Makefile b/Makefile
index 9a8a2474..7a5b1e0e 100644
--- a/Makefile
+++ b/Makefile
@@ -39,7 +39,7 @@ test-act-ete-train:
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		device=cpu \
-		training.save_model=true \
+		training.save_checkpoint=true \
 		training.save_freq=2 \
 		policy.n_action_steps=20 \
 		policy.chunk_size=20 \
@@ -65,7 +65,7 @@ test-act-ete-train-amp:
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		device=cpu \
-		training.save_model=true \
+		training.save_checkpoint=true \
 		training.save_freq=2 \
 		policy.n_action_steps=20 \
 		policy.chunk_size=20 \
@@ -95,7 +95,7 @@ test-diffusion-ete-train:
 		eval.n_episodes=1 \
 		eval.batch_size=1 \
 		device=cpu \
-		training.save_model=true \
+		training.save_checkpoint=true \
 		training.save_freq=2 \
 		training.batch_size=2 \
 		hydra.run.dir=tests/outputs/diffusion/
@@ -122,7 +122,7 @@ test-tdmpc-ete-train:
 		eval.batch_size=1 \
 		env.episode_length=2 \
 		device=cpu \
-		training.save_model=true \
+		training.save_checkpoint=true \
 		training.save_freq=2 \
 		training.batch_size=2 \
 		hydra.run.dir=tests/outputs/tdmpc/
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index d8398ff1..fe17e1d7 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -13,8 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
+
 # TODO(rcadene, alexander-soare): clean this file
-"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
+"""
 
 import logging
 import os
@@ -33,10 +35,6 @@ from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
 
 
-def log_output_dir(out_dir):
-    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
-
-
 def cfg_to_group(cfg, return_list=False):
     """Return a group name for logging. Optionally returns group name as list."""
     lst = [
@@ -57,13 +55,12 @@ class Logger:
         self._job_name = job_name
         self._checkpoint_dir = self._log_dir / "checkpoints"
         self._last_checkpoint_path = self._checkpoint_dir / "last"
-        self._buffer_dir = self._log_dir / "buffers"
-        self._save_model = cfg.training.save_model
         self._disable_wandb_artifact = cfg.wandb.disable_artifact
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
-        self._eval = []
+
+        # Set up WandB.
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
         enable_wandb = cfg.get("wandb", {}).get("enable", False)
@@ -112,20 +109,19 @@
         return self._last_checkpoint_path
 
     def save_model(self, policy: Policy, identifier: str):
-        if self._save_model:
-            self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
-            save_dir = self._checkpoint_dir / str(identifier)
-            policy.save_pretrained(save_dir)
-            # Also save the full Hydra config for the env configuration.
-            OmegaConf.save(self._cfg, save_dir / "config.yaml")
-            if self._wandb and not self._disable_wandb_artifact:
-                # note wandb artifact does not accept ":" or "/" in its name
-                artifact = self._wandb.Artifact(
-                    f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                    type="model",
-                )
-                artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
-                self._wandb.log_artifact(artifact)
+        self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        save_dir = self._checkpoint_dir / str(identifier)
+        policy.save_pretrained(save_dir)
+        # Also save the full Hydra config for the env configuration.
+        OmegaConf.save(self._cfg, save_dir / "config.yaml")
+        if self._wandb and not self._disable_wandb_artifact:
+            # note wandb artifact does not accept ":" or "/" in its name
+            artifact = self._wandb.Artifact(
+                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
+                type="model",
+            )
+            artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
+            self._wandb.log_artifact(artifact)
         if self._last_checkpoint_path.exists():
             os.remove(self._last_checkpoint_path)
         os.symlink(save_dir.absolute(), self._last_checkpoint_path)  # TODO(now): Check this works
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 42f2a92c..9b320059 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -38,7 +38,7 @@ training:
   eval_freq: ???
   save_freq: ???
   log_freq: 250
-  save_model: true
+  save_checkpoint: true
 
 eval:
   n_episodes: 1
@@ -49,7 +49,7 @@ eval:
 
 wandb:
   enable: false
-  # Set to true to disable saving an artifact despite save_model == True
+  # Set to true to disable saving an artifact despite save_checkpoint == True
   disable_artifact: false
   project: lerobot
   notes: ""
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index 7a12dcc2..bba2e563 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -15,7 +15,7 @@ training:
   eval_freq: 10000
   save_freq: 100000
   log_freq: 250
-  save_model: true
+  save_checkpoint: true
 
   batch_size: 8
   lr: 1e-5
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 7278985e..36bd22cc 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -27,7 +27,7 @@ training:
   eval_freq: 5000
   save_freq: 5000
   log_freq: 250
-  save_model: true
+  save_checkpoint: true
 
   batch_size: 64
   grad_clip_norm: 10
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index dc4ec1f6..1874f4d6 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -262,7 +262,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 f"You have set resume=True, but {str(logger.last_checkpoint_path)} does not exist."
             )
         # Get the configuration file from the last checkpoint.
-        checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path))
+        checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path / "config.yaml"))  # TODO(now): Do a diff check.
         cfg = checkpoint_cfg
         step = logger.load_last_training_state(optimizer, lr_scheduler)
 
@@ -297,7 +297,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")
 
-        if cfg.training.save_model and step % cfg.training.save_freq == 0:
+        if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).