# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
|
|
|
Once you have trained a model with this script, you can try to evaluate it on
|
|
examples/2_evaluate_pretrained_policy.py
|
|
"""
|
|
|
|
from pathlib import Path

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.types import FeatureType


def main():
    # Create a directory to store the training checkpoint.
    output_directory = Path("outputs/train/example_pusht_diffusion")
    output_directory.mkdir(parents=True, exist_ok=True)

    # Select your device
    device = torch.device("cuda")
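    # A minimal fallback sketch if a CUDA GPU may not be available (assumption: you accept
    # much slower training on CPU in that case):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")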

    # Number of offline training steps (we'll only do offline training for this example.)
    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
    training_steps = 5000
    log_freq = 1

    # When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
    # creating the policy:
    #   - input/output shapes: to properly size the policy
    #   - dataset stats: for normalization and denormalization of input/outputs
    dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
    features = dataset_to_policy_features(dataset_metadata.features)
    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    input_features = {key: ft for key, ft in features.items() if key not in output_features}
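
    # (Optional) Inspect the inferred features. For lerobot/pusht these are typically
    # "observation.image" and "observation.state" as inputs and "action" as output,
    # but the exact keys come from the dataset metadata, so this is only an assumption:
    # for key, ft in {**input_features, **output_features}.items():
    #     print(key, ft.type, ft.shape)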

    # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
    # we'll just use the defaults and so no arguments other than input/output features need to be passed.
    cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
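    # Defaults can be overridden via keyword arguments; a hedged sketch (field names like
    # `n_obs_steps` and `horizon` are assumed to exist in your installed DiffusionConfig):
    # cfg = DiffusionConfig(
    #     input_features=input_features,
    #     output_features=output_features,
    #     n_obs_steps=2,  # number of observation frames fed to the policy
    #     horizon=16,     # length of the predicted action sequence
    # )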

    # We can now instantiate our policy with this config and the dataset stats.
    policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
    policy.train()
    policy.to(device)

    # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number
    # of frames, which can differ for inputs, outputs and rewards (if there are some).
    delta_timestamps = {
        "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
        "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
        "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
    }

    # With the standard configuration for Diffusion Policy, this is equivalent to:
    delta_timestamps = {
        # Load the previous image and state at -0.1 seconds before the current frame,
        # then load the current image and state corresponding to 0.0 second.
        "observation.image": [-0.1, 0.0],
        "observation.state": [-0.1, 0.0],
        # Load the previous action (-0.1), the next action to be executed (0.0),
        # and 14 future actions with a 0.1 second spacing. All these actions will be
        # used to supervise the policy.
        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
    }

    # We can then instantiate the dataset with this delta_timestamps configuration.
    dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
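
    # (Optional) Sanity-check one sample: with the delta_timestamps above, each key gains a
    # leading time dimension (2 stacked frames for observations, 16 for actions). The exact
    # shapes below are assumptions based on that configuration:
    # sample = dataset[0]
    # print(sample["observation.image"].shape)  # e.g. (2, C, H, W)
    # print(sample["action"].shape)             # e.g. (16, action_dim)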

    # Then we create our optimizer and dataloader for offline training.
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=64,
        shuffle=True,
        # Pinning host memory speeds up host-to-GPU transfers; it is pointless on CPU.
        pin_memory=device.type != "cpu",
        drop_last=True,
    )
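
    # Note (assumption about your platform): if multiprocessing data loading misbehaves,
    # e.g. in notebooks or on macOS/Windows, num_workers=0 is a slower but safe fallback.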

    # Run training loop.
    step = 0
    done = False
    while not done:
        for batch in dataloader:
            # Move tensors to the training device; non-tensor entries pass through unchanged.
            batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
            loss, _ = policy.forward(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
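            # (Optional, not part of the original recipe) If training proves unstable,
            # clipping gradients before optimizer.step() is a common remedy:
            # torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)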

            if step % log_freq == 0:
                print(f"step: {step} loss: {loss.item():.3f}")
            step += 1
            if step >= training_steps:
                done = True
                break

    # Save a policy checkpoint.
    policy.save_pretrained(output_directory)
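
    # The checkpoint can later be reloaded for evaluation, e.g.:
    # policy = DiffusionPolicy.from_pretrained(output_directory)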


if __name__ == "__main__":
    main()