# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """This scripts demonstrates how to train Diffusion Policy on the PushT environment. Once you have trained a model with this script, you can try to evaluate it on examples/2_evaluate_pretrained_policy.py """ from pathlib import Path import torch from lerobot.common.datasets.lerobot_dataset import ( LeRobotDataset, LeRobotDatasetMetadata, ) from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.configs.types import FeatureType def main(): # Create a directory to store the training checkpoint. output_directory = Path("outputs/train/example_pusht_diffusion") output_directory.mkdir(parents=True, exist_ok=True) # # Select your device device = torch.device("cuda") # Number of offline training steps (we'll only do offline training for this example.) # Adjust as you prefer. 5000 steps are needed to get something worth evaluating. training_steps = 5000 log_freq = 1 # When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before # creating the policy: # - input/output shapes: to properly size the policy # - dataset stats: for normalization and denormalization of input/outputs dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht") features = dataset_to_policy_features(dataset_metadata.features) output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} input_features = {key: ft for key, ft in features.items() if key not in output_features} # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example, # we'll just use the defaults and so no arguments other than input/output features need to be passed. cfg = DiffusionConfig(input_features=input_features, output_features=output_features) # We can now instantiate our policy with this config and the dataset stats. policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats) policy.train() policy.to(device) # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames # which can differ for inputs, outputs and rewards (if there are some). delta_timestamps = { "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices], "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices], "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices], } # In this case with the standard configuration for Diffusion Policy, it is equivalent to this: delta_timestamps = { # Load the previous image and state at -0.1 seconds before current frame, # then load current image and state corresponding to 0.0 second. "observation.image": [-0.1, 0.0], "observation.state": [-0.1, 0.0], # Load the previous action (-0.1), the next action to be executed (0.0), # and 14 future actions with a 0.1 seconds spacing. All these actions will be # used to supervise the policy. "action": [ -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, ], } # We can then instantiate the dataset with these delta_timestamps configuration. dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps) # Then we create our optimizer and dataloader for offline training. optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4) dataloader = torch.utils.data.DataLoader( dataset, num_workers=4, batch_size=64, shuffle=True, pin_memory=device.type != "cpu", drop_last=True, ) # Run training loop. step = 0 done = False while not done: for batch in dataloader: batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} loss, _ = policy.forward(batch) loss.backward() optimizer.step() optimizer.zero_grad() if step % log_freq == 0: print(f"step: {step} loss: {loss.item():.3f}") step += 1 if step >= training_steps: done = True break # Save a policy checkpoint. policy.save_pretrained(output_directory) if __name__ == "__main__": main()