backup wip

This commit is contained in:
Alexander Soare 2024-05-21 09:01:34 +01:00
parent 54ec151cbb
commit 77b61e364e
6 changed files with 28 additions and 32 deletions

View File

@@ -39,7 +39,7 @@ test-act-ete-train:
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
-training.save_model=true \
+training.save_checkpoint=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
@@ -65,7 +65,7 @@ test-act-ete-train-amp:
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
-training.save_model=true \
+training.save_checkpoint=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
@@ -95,7 +95,7 @@ test-diffusion-ete-train:
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
-training.save_model=true \
+training.save_checkpoint=true \
training.save_freq=2 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/diffusion/
@@ -122,7 +122,7 @@ test-tdmpc-ete-train:
eval.batch_size=1 \
env.episode_length=2 \
device=cpu \
-training.save_model=true \
+training.save_checkpoint=true \
training.save_freq=2 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/tdmpc/

View File

@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
+# TODO(rcadene, alexander-soare): clean this file
+"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
-"""
import logging
import os
@@ -33,10 +35,6 @@ from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg, return_list=False):
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
@@ -57,13 +55,12 @@ class Logger:
self._job_name = job_name
self._checkpoint_dir = self._log_dir / "checkpoints"
self._last_checkpoint_path = self._checkpoint_dir / "last"
-self._buffer_dir = self._log_dir / "buffers"
-self._save_model = cfg.training.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._cfg = cfg
self._eval = []
# Set up WandB.
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False)
@@ -112,20 +109,19 @@ class Logger:
return self._last_checkpoint_path
def save_model(self, policy: Policy, identifier: str):
-if self._save_model:
-self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
-save_dir = self._checkpoint_dir / str(identifier)
-policy.save_pretrained(save_dir)
-# Also save the full Hydra config for the env configuration.
-OmegaConf.save(self._cfg, save_dir / "config.yaml")
-if self._wandb and not self._disable_wandb_artifact:
-# note wandb artifact does not accept ":" or "/" in its name
-artifact = self._wandb.Artifact(
-f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-type="model",
-)
-artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
-self._wandb.log_artifact(artifact)
+self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
+save_dir = self._checkpoint_dir / str(identifier)
+policy.save_pretrained(save_dir)
+# Also save the full Hydra config for the env configuration.
+OmegaConf.save(self._cfg, save_dir / "config.yaml")
+if self._wandb and not self._disable_wandb_artifact:
+# note wandb artifact does not accept ":" or "/" in its name
+artifact = self._wandb.Artifact(
+f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
+type="model",
+)
+artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
+self._wandb.log_artifact(artifact)
if self._last_checkpoint_path.exists():
os.remove(self._last_checkpoint_path)
os.symlink(save_dir.absolute(), self._last_checkpoint_path) # TODO(now): Check this works

View File

@@ -38,7 +38,7 @@ training:
eval_freq: ???
save_freq: ???
log_freq: 250
-save_model: true
+save_checkpoint: true
eval:
n_episodes: 1
@@ -49,7 +49,7 @@ eval:
wandb:
enable: false
-# Set to true to disable saving an artifact despite save_model == True
+# Set to true to disable saving an artifact despite save_checkpoint == True
disable_artifact: false
project: lerobot
notes: ""

View File

@@ -15,7 +15,7 @@ training:
eval_freq: 10000
save_freq: 100000
log_freq: 250
-save_model: true
+save_checkpoint: true
batch_size: 8
lr: 1e-5

View File

@@ -27,7 +27,7 @@ training:
eval_freq: 5000
save_freq: 5000
log_freq: 250
-save_model: true
+save_checkpoint: true
batch_size: 64
grad_clip_norm: 10

View File

@@ -262,7 +262,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
f"You have set resume=True, but {str(logger.last_checkpoint_path)} does not exist."
)
# Get the configuration file from the last checkpoint.
-checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path))
+checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path / "config.yaml"))
# TODO(now): Do a diff check.
cfg = checkpoint_cfg
step = logger.load_last_training_state(optimizer, lr_scheduler)
@@ -297,7 +297,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
-if cfg.training.save_model and step % cfg.training.save_freq == 0:
+if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).