backup wip
This commit is contained in:
parent
54ec151cbb
commit
77b61e364e
8
Makefile
8
Makefile
|
@ -39,7 +39,7 @@ test-act-ete-train:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
training.save_model=true \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
|
@ -65,7 +65,7 @@ test-act-ete-train-amp:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
training.save_model=true \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
|
@ -95,7 +95,7 @@ test-diffusion-ete-train:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
training.save_model=true \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
|
@ -122,7 +122,7 @@ test-tdmpc-ete-train:
|
|||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=cpu \
|
||||
training.save_model=true \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
|
|
@ -13,8 +13,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
|
||||
|
||||
# TODO(rcadene, alexander-soare): clean this file
|
||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
@ -33,10 +35,6 @@ from lerobot.common.policies.policy_protocol import Policy
|
|||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg, return_list=False):
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
|
@ -57,13 +55,12 @@ class Logger:
|
|||
self._job_name = job_name
|
||||
self._checkpoint_dir = self._log_dir / "checkpoints"
|
||||
self._last_checkpoint_path = self._checkpoint_dir / "last"
|
||||
self._buffer_dir = self._log_dir / "buffers"
|
||||
self._save_model = cfg.training.save_model
|
||||
self._disable_wandb_artifact = cfg.wandb.disable_artifact
|
||||
self._group = cfg_to_group(cfg)
|
||||
self._seed = cfg.seed
|
||||
self._cfg = cfg
|
||||
self._eval = []
|
||||
|
||||
# Set up WandB.
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
|
@ -112,20 +109,19 @@ class Logger:
|
|||
return self._last_checkpoint_path
|
||||
|
||||
def save_model(self, policy: Policy, identifier: str):
|
||||
if self._save_model:
|
||||
self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_dir = self._checkpoint_dir / str(identifier)
|
||||
policy.save_pretrained(save_dir)
|
||||
# Also save the full Hydra config for the env configuration.
|
||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_dir = self._checkpoint_dir / str(identifier)
|
||||
policy.save_pretrained(save_dir)
|
||||
# Also save the full Hydra config for the env configuration.
|
||||
OmegaConf.save(self._cfg, save_dir / "config.yaml")
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" or "/" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
if self._last_checkpoint_path.exists():
|
||||
os.remove(self._last_checkpoint_path)
|
||||
os.symlink(save_dir.absolute(), self._last_checkpoint_path) # TODO(now): Check this works
|
||||
|
|
|
@ -38,7 +38,7 @@ training:
|
|||
eval_freq: ???
|
||||
save_freq: ???
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
save_checkpoint: true
|
||||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
|
@ -49,7 +49,7 @@ eval:
|
|||
|
||||
wandb:
|
||||
enable: false
|
||||
# Set to true to disable saving an artifact despite save_model == True
|
||||
# Set to true to disable saving an artifact despite save_checkpoint == True
|
||||
disable_artifact: false
|
||||
project: lerobot
|
||||
notes: ""
|
||||
|
|
|
@ -15,7 +15,7 @@ training:
|
|||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
|
|
|
@ -27,7 +27,7 @@ training:
|
|||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
save_model: true
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
|
|
|
@ -262,7 +262,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
f"You have set resume=True, but {str(logger.last_checkpoint_path)} does not exist."
|
||||
)
|
||||
# Get the configuration file from the last checkpoint.
|
||||
checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path))
|
||||
checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path / "config.yaml"))
|
||||
# TODO(now): Do a diff check.
|
||||
cfg = checkpoint_cfg
|
||||
step = logger.load_last_training_state(optimizer, lr_scheduler)
|
||||
|
@ -297,7 +297,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.training.save_model and step % cfg.training.save_freq == 0:
|
||||
if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
|
|
Loading…
Reference in New Issue