From a9496fde3976a29ac41674a5c1e826245dd42f22 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 16 Apr 2024 17:15:51 +0100 Subject: [PATCH] revision 1 --- examples/3_train_policy.py | 10 ++-------- lerobot/common/policies/act/modeling_act.py | 10 +++++++--- .../common/policies/diffusion/modeling_diffusion.py | 13 ++++++++++--- tests/test_examples.py | 3 ++- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index d2e8b8c9..0c8decc4 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -53,16 +53,10 @@ step = 0 done = False while not done: for batch in dataloader: - for k in batch: - batch[k] = batch[k].to(device, non_blocking=True) + batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} info = policy(batch) if step % log_freq == 0: - num_samples = (step + 1) * cfg.batch_size - loss = info["loss"] - update_s = info["update_s"] - print( - f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)" - ) + print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)") step += 1 if step >= training_steps: done = True diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 18ea3377..5f2429a6 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module): "ActionChunkingTransformerPolicy does not handle multiple observation steps." ) - def __init__(self, cfg: ActionChunkingTransformerConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig | None = None): """ - TODO(alexander-soare): Add documentation for all parameters once we have model configs established. + Args: + cfg: Policy configuration class instance or None, in which case the default instantiation of the + configuration class is used. """ super().__init__() - if getattr(cfg, "n_obs_steps", 1) != 1: + if cfg is None: + cfg = ActionChunkingTransformerConfig() + if cfg.n_obs_steps != 1: raise ValueError(self._multiple_obs_steps_not_handled_msg) self.cfg = cfg diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 9a02c6a2..dfab9bb7 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -33,8 +33,6 @@ from lerobot.common.policies.utils import ( populate_queues, ) -logger = logging.getLogger(__name__) - class DiffusionPolicy(nn.Module): """ @@ -44,8 +42,17 @@ class DiffusionPolicy(nn.Module): name = "diffusion" - def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int): + def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0): super().__init__() + """ + Args: + cfg: Policy configuration class instance or None, in which case the default instantiation of the + configuration class is used. + """ + # TODO(alexander-soare): LR scheduler will be removed. + assert lr_scheduler_num_training_steps > 0 + if cfg is None: + cfg = DiffusionConfig() self.cfg = cfg # queues are populated during rollout of the policy, they contain the n latest observations and actions diff --git a/tests/test_examples.py b/tests/test_examples.py index 6cab7a1a..c510eb1e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -40,7 +40,8 @@ def test_examples_3_and_2(): ], ) - exec(file_contents) + # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. + exec(file_contents, {}) for file_name in ["model.pt", "stats.pth", "config.yaml"]: assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()