revision 1
This commit is contained in:
parent
43a614c173
commit
a9496fde39
|
@ -53,16 +53,10 @@ step = 0
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
for k in batch:
|
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||||
batch[k] = batch[k].to(device, non_blocking=True)
|
|
||||||
info = policy(batch)
|
info = policy(batch)
|
||||||
if step % log_freq == 0:
|
if step % log_freq == 0:
|
||||||
num_samples = (step + 1) * cfg.batch_size
|
print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
|
||||||
loss = info["loss"]
|
|
||||||
update_s = info["update_s"]
|
|
||||||
print(
|
|
||||||
f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)"
|
|
||||||
)
|
|
||||||
step += 1
|
step += 1
|
||||||
if step >= training_steps:
|
if step >= training_steps:
|
||||||
done = True
|
done = True
|
||||||
|
|
|
@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
|
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||||
"""
|
"""
|
||||||
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
|
Args:
|
||||||
|
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||||
|
configuration class is used.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if getattr(cfg, "n_obs_steps", 1) != 1:
|
if cfg is None:
|
||||||
|
cfg = ActionChunkingTransformerConfig()
|
||||||
|
if cfg.n_obs_steps != 1:
|
||||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
|
|
|
@ -33,8 +33,6 @@ from lerobot.common.policies.utils import (
|
||||||
populate_queues,
|
populate_queues,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPolicy(nn.Module):
|
class DiffusionPolicy(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
@ -44,8 +42,17 @@ class DiffusionPolicy(nn.Module):
|
||||||
|
|
||||||
name = "diffusion"
|
name = "diffusion"
|
||||||
|
|
||||||
def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int):
|
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||||
|
configuration class is used.
|
||||||
|
"""
|
||||||
|
# TODO(alexander-soare): LR scheduler will be removed.
|
||||||
|
assert lr_scheduler_num_training_steps > 0
|
||||||
|
if cfg is None:
|
||||||
|
cfg = DiffusionConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
|
|
|
@ -40,7 +40,8 @@ def test_examples_3_and_2():
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
exec(file_contents)
|
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
|
||||||
|
exec(file_contents, {})
|
||||||
|
|
||||||
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
|
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
|
||||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||||
|
|
Loading…
Reference in New Issue