backup wip
parent 54ec151cbb
commit 77b61e364e

Makefile | 8 ++++----
@@ -39,7 +39,7 @@ test-act-ete-train:
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	device=cpu \
-	training.save_model=true \
+	training.save_checkpoint=true \
 	training.save_freq=2 \
 	policy.n_action_steps=20 \
 	policy.chunk_size=20 \
@@ -65,7 +65,7 @@ test-act-ete-train-amp:
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	device=cpu \
-	training.save_model=true \
+	training.save_checkpoint=true \
 	training.save_freq=2 \
 	policy.n_action_steps=20 \
 	policy.chunk_size=20 \
@@ -95,7 +95,7 @@ test-diffusion-ete-train:
 	eval.n_episodes=1 \
 	eval.batch_size=1 \
 	device=cpu \
-	training.save_model=true \
+	training.save_checkpoint=true \
 	training.save_freq=2 \
 	training.batch_size=2 \
 	hydra.run.dir=tests/outputs/diffusion/
@@ -122,7 +122,7 @@ test-tdmpc-ete-train:
 	eval.batch_size=1 \
 	env.episode_length=2 \
 	device=cpu \
-	training.save_model=true \
+	training.save_checkpoint=true \
 	training.save_freq=2 \
 	training.batch_size=2 \
 	hydra.run.dir=tests/outputs/tdmpc/
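The four test targets above only change the Hydra override they pass: training.save_model=true becomes training.save_checkpoint=true. As a rough illustration of what such a dotlist override does to the composed config (a minimal OmegaConf sketch, not the repo's actual Hydra setup; the base dict below is a stand-in for the real default config):

# Minimal sketch: how a dotlist override like the ones in the test-* recipes
# lands in the composed config. Not the repo's Hydra composition.
from omegaconf import OmegaConf

# Stand-in for the training section of the default config.
base = OmegaConf.create({"training": {"save_checkpoint": False, "save_freq": 100000}})

# Hydra-style overrides as written in the Makefile recipes above.
overrides = OmegaConf.from_dotlist(["training.save_checkpoint=true", "training.save_freq=2"])

cfg = OmegaConf.merge(base, overrides)
assert cfg.training.save_checkpoint is True
assert cfg.training.save_freq == 2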

@@ -13,8 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
+
 # TODO(rcadene, alexander-soare): clean this file
-"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
+"""
 
 import logging
 import os
@@ -33,10 +35,6 @@ from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
 
 
-def log_output_dir(out_dir):
-    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
-
-
 def cfg_to_group(cfg, return_list=False):
     """Return a group name for logging. Optionally returns group name as list."""
     lst = [
@@ -57,13 +55,12 @@ class Logger:
         self._job_name = job_name
         self._checkpoint_dir = self._log_dir / "checkpoints"
         self._last_checkpoint_path = self._checkpoint_dir / "last"
-        self._buffer_dir = self._log_dir / "buffers"
-        self._save_model = cfg.training.save_model
         self._disable_wandb_artifact = cfg.wandb.disable_artifact
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
-        self._eval = []
+
+        # Set up WandB.
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
         enable_wandb = cfg.get("wandb", {}).get("enable", False)
@@ -112,20 +109,19 @@ class Logger:
         return self._last_checkpoint_path
 
     def save_model(self, policy: Policy, identifier: str):
-        if self._save_model:
-            self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
-            save_dir = self._checkpoint_dir / str(identifier)
-            policy.save_pretrained(save_dir)
-            # Also save the full Hydra config for the env configuration.
-            OmegaConf.save(self._cfg, save_dir / "config.yaml")
-            if self._wandb and not self._disable_wandb_artifact:
-                # note wandb artifact does not accept ":" or "/" in its name
-                artifact = self._wandb.Artifact(
-                    f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
-                    type="model",
-                )
-                artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
-                self._wandb.log_artifact(artifact)
+        self._checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        save_dir = self._checkpoint_dir / str(identifier)
+        policy.save_pretrained(save_dir)
+        # Also save the full Hydra config for the env configuration.
+        OmegaConf.save(self._cfg, save_dir / "config.yaml")
+        if self._wandb and not self._disable_wandb_artifact:
+            # note wandb artifact does not accept ":" or "/" in its name
+            artifact = self._wandb.Artifact(
+                f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
+                type="model",
+            )
+            artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
+            self._wandb.log_artifact(artifact)
 
         if self._last_checkpoint_path.exists():
             os.remove(self._last_checkpoint_path)
         os.symlink(save_dir.absolute(), self._last_checkpoint_path)  # TODO(now): Check this works
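The logger changes above drop the internal _save_model gate (the caller now decides whether to checkpoint, see the train loop diff further down) and keep the pattern of re-pointing a "last" symlink at the newest checkpoint directory. A self-contained sketch of that symlink pattern, using throwaway temp paths rather than lerobot's log directory, relevant to the TODO(now) note on the last line:

# Self-contained sketch of the "last"-checkpoint symlink handling used in
# Logger.save_model above; paths are throwaway temp dirs, not the repo's.
import os
import tempfile
from pathlib import Path

checkpoint_dir = Path(tempfile.mkdtemp()) / "checkpoints"
last = checkpoint_dir / "last"

for identifier in ["000002", "000004"]:
    save_dir = checkpoint_dir / identifier
    save_dir.mkdir(parents=True, exist_ok=True)
    # Same two steps as the diff: drop the old link if present, then re-link.
    if last.exists():
        os.remove(last)
    os.symlink(save_dir.absolute(), last)

print(os.readlink(last))  # ends with checkpoints/000004

One case the exists() check does not cover is a dangling symlink, since Path.exists follows the link; that may be what the TODO(now) is flagging, and Path.is_symlink (or os.path.islink) would catch it.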

@@ -38,7 +38,7 @@ training:
   eval_freq: ???
   save_freq: ???
   log_freq: 250
-  save_model: true
+  save_checkpoint: true
 
 eval:
   n_episodes: 1
@@ -49,7 +49,7 @@ eval:
 
 wandb:
   enable: false
-  # Set to true to disable saving an artifact despite save_model == True
+  # Set to true to disable saving an artifact despite save_checkpoint == True
   disable_artifact: false
   project: lerobot
   notes: ""
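For context on the unchanged lines above: the ??? values (eval_freq, save_freq) are OmegaConf/Hydra mandatory-value markers, so they must be supplied by a policy config or a CLI override, while save_checkpoint ships with a usable default. A small sketch of that behavior:

# Sketch of OmegaConf's "???" (mandatory value) behavior seen in the config above.
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"training": {"save_freq": "???", "save_checkpoint": True}})

print(cfg.training.save_checkpoint)  # True: defaults are readable right away

try:
    _ = cfg.training.save_freq
except MissingMandatoryValue:
    print("training.save_freq must come from a policy config or a CLI override")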
@@ -15,7 +15,7 @@ training:
   eval_freq: 10000
   save_freq: 100000
   log_freq: 250
-  save_model: true
+  save_checkpoint: true
 
   batch_size: 8
   lr: 1e-5
@@ -27,7 +27,7 @@ training:
   eval_freq: 5000
   save_freq: 5000
   log_freq: 250
-  save_model: true
+  save_checkpoint: true
 
   batch_size: 64
   grad_clip_norm: 10
@@ -262,7 +262,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 f"You have set resume=True, but {str(logger.last_checkpoint_path)} does not exist."
             )
         # Get the configuration file from the last checkpoint.
-        checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path))
+        checkpoint_cfg = init_hydra_config(str(logger.last_checkpoint_path / "config.yaml"))
         # TODO(now): Do a diff check.
         cfg = checkpoint_cfg
         step = logger.load_last_training_state(optimizer, lr_scheduler)
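Since Logger.save_model writes config.yaml into each checkpoint directory (see the logger diff above), the resume path now points init_hydra_config at that file rather than at the directory itself. A hedged sketch of that load plus one possible shape for the "diff check" the TODO mentions; the function names here are illustrative, not from the repo, and OmegaConf.load stands in for init_hydra_config:

# Illustrative sketch only.
from pathlib import Path
from omegaconf import OmegaConf

def load_checkpoint_cfg(last_checkpoint_path: Path):
    # config.yaml is saved next to the weights by Logger.save_model.
    return OmegaConf.load(last_checkpoint_path / "config.yaml")

def warn_on_config_drift(current_cfg, checkpoint_cfg, keys=("env", "policy", "training")):
    # One possible "diff check": flag top-level sections that changed since the checkpoint.
    for key in keys:
        if OmegaConf.select(current_cfg, key) != OmegaConf.select(checkpoint_cfg, key):
            print(f"warning: '{key}' differs between the current config and the checkpoint")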
@@ -297,7 +297,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")
 
-        if cfg.training.save_model and step % cfg.training.save_freq == 0:
+        if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).
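The note at the end of the hunk describes the checkpoint identifier format: the step number zero-padded to at least 6 digits, widened if the total number of steps needs more. A small sketch of that formatting; step_identifier and the offline_steps key in the comment are illustrative, not lifted from the diff:

# Sketch of the identifier format described in the comment above.
def step_identifier(step: int, total_steps: int, min_digits: int = 6) -> str:
    num_digits = max(min_digits, len(str(total_steps)))
    return f"{step:0{num_digits}d}"

assert step_identifier(2, 100000) == "000002"       # 6-digit minimum
assert step_identifier(2, 20000000) == "00000002"   # widened to fit total_steps

# Used together with the guard shown in the diff, e.g. (illustrative):
#   if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
#       logger.save_model(policy, step_identifier(step, cfg.training.offline_steps))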