revision 1

This commit is contained in:
Alexander Soare 2024-04-16 17:15:51 +01:00
parent 43a614c173
commit a9496fde39
4 changed files with 21 additions and 15 deletions

View File

@ -53,16 +53,10 @@ step = 0
done = False done = False
while not done: while not done:
for batch in dataloader: for batch in dataloader:
for k in batch: batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
batch[k] = batch[k].to(device, non_blocking=True)
info = policy(batch) info = policy(batch)
if step % log_freq == 0: if step % log_freq == 0:
num_samples = (step + 1) * cfg.batch_size print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
loss = info["loss"]
update_s = info["update_s"]
print(
f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)"
)
step += 1 step += 1
if step >= training_steps: if step >= training_steps:
done = True done = True

View File

@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
"ActionChunkingTransformerPolicy does not handle multiple observation steps." "ActionChunkingTransformerPolicy does not handle multiple observation steps."
) )
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
""" """
TODO(alexander-soare): Add documentation for all parameters once we have model configs established. Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
""" """
super().__init__() super().__init__()
if getattr(cfg, "n_obs_steps", 1) != 1: if cfg is None:
cfg = ActionChunkingTransformerConfig()
if cfg.n_obs_steps != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg) raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg self.cfg = cfg

View File

@ -33,8 +33,6 @@ from lerobot.common.policies.utils import (
populate_queues, populate_queues,
) )
logger = logging.getLogger(__name__)
class DiffusionPolicy(nn.Module): class DiffusionPolicy(nn.Module):
""" """
@ -44,8 +42,17 @@ class DiffusionPolicy(nn.Module):
name = "diffusion" name = "diffusion"
def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int): def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
super().__init__() super().__init__()
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
# TODO(alexander-soare): LR scheduler will be removed.
assert lr_scheduler_num_training_steps > 0
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg self.cfg = cfg
# queues are populated during rollout of the policy, they contain the n latest observations and actions # queues are populated during rollout of the policy, they contain the n latest observations and actions

View File

@ -40,7 +40,8 @@ def test_examples_3_and_2():
], ],
) )
exec(file_contents) # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
for file_name in ["model.pt", "stats.pth", "config.yaml"]: for file_name in ["model.pt", "stats.pth", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()