Merge remote-tracking branch 'origin/main' into user/alexander-soare/multistep_policy_and_serial_env

This commit is contained in:
Alexander Soare 2024-03-15 14:06:53 +00:00
commit bae7e7b41c
2 changed files with 6 additions and 2 deletions

View File

@ -30,6 +30,7 @@ class Logger:
self._model_dir = self._log_dir / "models" self._model_dir = self._log_dir / "models"
self._buffer_dir = self._log_dir / "buffers" self._buffer_dir = self._log_dir / "buffers"
self._save_model = cfg.save_model self._save_model = cfg.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact
self._save_buffer = cfg.save_buffer self._save_buffer = cfg.save_buffer
self._group = cfg_to_group(cfg) self._group = cfg_to_group(cfg)
self._seed = cfg.seed self._seed = cfg.seed
@ -71,9 +72,10 @@ class Logger:
self._model_dir.mkdir(parents=True, exist_ok=True) self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt" fp = self._model_dir / f"{str(identifier)}.pt"
policy.save(fp) policy.save(fp)
if self._wandb: if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact( artifact = self._wandb.Artifact(
self._group + "-" + str(self._seed) + "-" + str(identifier), self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
type="model", type="model",
) )
artifact.add_file(fp) artifact.add_file(fp)

View File

@ -32,5 +32,7 @@ policy: ???
wandb: wandb:
enable: true enable: true
# Set to true to disable saving an artifact despite save_model == True
disable_artifact: false
project: lerobot project: lerobot
notes: "" notes: ""