Merge remote-tracking branch 'origin/main' into user/alexander-soare/multistep_policy_and_serial_env

This commit is contained in:
Alexander Soare 2024-03-15 14:06:53 +00:00
commit bae7e7b41c
2 changed files with 6 additions and 2 deletions

View File

@ -30,6 +30,7 @@ class Logger:
self._model_dir = self._log_dir / "models"
self._buffer_dir = self._log_dir / "buffers"
self._save_model = cfg.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact
self._save_buffer = cfg.save_buffer
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
@ -71,9 +72,10 @@ class Logger:
self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt"
policy.save(fp)
if self._wandb:
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(
self._group + "-" + str(self._seed) + "-" + str(identifier),
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
type="model",
)
artifact.add_file(fp)

View File

@ -32,5 +32,7 @@ policy: ???
wandb:
enable: true
# Set to true to disable saving an artifact despite save_model == True
disable_artifact: false
project: lerobot
notes: ""