#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
from typing import Optional, Sequence, TypedDict

import hydra
import numpy as np
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, ListConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler
from tqdm import tqdm

from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    init_logging,
    set_global_seed,
)
from lerobot.scripts.eval import eval_policy
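
# This script trains a SAC policy online: it rolls out a single environment, stores transitions
# in a `ReplayBuffer`, optionally mixes them with an offline replay buffer converted from a
# `LeRobotDataset` (when `cfg.dataset_repo_id` is set), and runs critic / actor / temperature
# updates inside `train()` below.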


def make_optimizers_and_scheduler(cfg, policy):
    optimizer_actor = torch.optim.Adam(
        params=policy.actor.parameters(),
        lr=policy.config.actor_lr,
    )
    optimizer_critic = torch.optim.Adam(
        params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
    )
    # We wrap the policy's log temperature in a list because it is a plain torch tensor, not an nn.Module.
    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
    lr_scheduler = None
    optimizers = {
        "actor": optimizer_actor,
        "critic": optimizer_critic,
        "temperature": optimizer_temperature,
    }
    return optimizers, lr_scheduler
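
# SAC keeps three separate Adam optimizers: one for the actor, one for the critic ensemble, and
# one for the entropy temperature `log_alpha`. As written above, the temperature optimizer reuses
# `policy.config.critic_lr`, and no learning-rate scheduler is created yet (`lr_scheduler` is None).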


# def update_policy(policy, batch, optimizers, grad_clip_norm):


# NOTE: This is temporary; neither the OnlineBuffer nor querying the LeRobotDataset is performant
# enough yet, hence the custom ReplayBuffer below.


class Transition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: float
    next_state: dict[str, torch.Tensor]
    done: bool
    complementary_info: Optional[dict[str, torch.Tensor]]


class BatchTransition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: torch.Tensor
    next_state: dict[str, torch.Tensor]
    done: torch.Tensor
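
# `Transition` stores a single step, with a leading batch dimension of 1 on every tensor (see
# `_lerobotdataset_to_transitions` and the collection loop in `train()`), while `BatchTransition`
# holds the same fields stacked along dim 0 after sampling from the replay buffer.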


class ReplayBuffer:
    def __init__(self, capacity: int, device: str = "cuda:0", state_keys: Optional[Sequence[str]] = None):
        """
        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
            device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
        """
        self.capacity = capacity
        self.device = device
        self.memory: list[Transition] = []
        self.position = 0

        # If no state_keys are provided, default to an empty list
        # (you can handle this differently if needed).
        self.state_keys = state_keys if state_keys is not None else []

    def add(
        self,
        state: dict[str, torch.Tensor],
        action: torch.Tensor,
        reward: float,
        next_state: dict[str, torch.Tensor],
        done: bool,
        complementary_info: Optional[dict[str, torch.Tensor]] = None,
    ):
        """Saves a transition."""
        # Grow the list until capacity is reached; afterwards `position` wraps around and the
        # oldest transitions are overwritten (circular buffer).
        if len(self.memory) < self.capacity:
            self.memory.append(None)

        # Create and store the Transition
        self.memory[self.position] = Transition(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            complementary_info=complementary_info,
        )
        self.position = (self.position + 1) % self.capacity

    @classmethod
    def from_lerobot_dataset(
        cls,
        lerobot_dataset: LeRobotDataset,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
    ) -> "ReplayBuffer":
        # We convert the LeRobotDataset into a replay buffer because it is more efficient to sample
        # from a replay buffer than from a LeRobotDataset.
        replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
        # Fill the replay buffer with the lerobot dataset transitions
        for data in list_transition:
            replay_buffer.add(
                state=data["state"],
                action=data["action"],
                reward=data["reward"],
                next_state=data["next_state"],
                done=data["done"],
            )
        return replay_buffer

    @staticmethod
    def _lerobotdataset_to_transitions(
        dataset: LeRobotDataset,
        state_keys: Optional[Sequence[str]] = None,
    ) -> list[Transition]:
        """
        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.

        Args:
            dataset (LeRobotDataset):
                The dataset to convert. Each item in the dataset is expected to have
                at least the following keys:
                {
                    "action": ...
                    "next.reward": ...
                    "next.done": ...
                    "episode_index": ...
                }
                plus whatever your `state_keys` specify.

            state_keys (Optional[Sequence[str]]):
                The dataset keys to include in 'state' and 'next_state'. Their names
                will be kept as-is in the output transitions. E.g.
                ["observation.state", "observation.environment_state"].
                If None, you must handle or define default keys.

        Returns:
            transitions (List[Transition]):
                A list of Transition dictionaries with the same length as `dataset`.
        """

        # If not provided, you can either raise an error or define a default:
        if state_keys is None:
            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")

        transitions: list[Transition] = []
        num_frames = len(dataset)

        for i in tqdm(range(num_frames)):
            current_sample = dataset[i]

            # ----- 1) Current state -----
            current_state: dict[str, torch.Tensor] = {}
            for key in state_keys:
                val = current_sample[key]
                current_state[key] = val.unsqueeze(0)  # Add batch dimension

            # ----- 2) Action -----
            action = current_sample["action"].unsqueeze(0)  # Add batch dimension

            # ----- 3) Reward and done -----
            reward = float(current_sample["next.reward"].item())  # ensure float
            done = bool(current_sample["next.done"].item())  # ensure bool

            # ----- 4) Next state -----
            # If not done and the next sample is in the same episode, we pull the next sample's state.
            # Otherwise (done=True or the next sample crosses to a new episode), next_state = current_state.
            next_state = current_state  # default
            if not done and (i < num_frames - 1):
                next_sample = dataset[i + 1]
                if next_sample["episode_index"] == current_sample["episode_index"]:
                    # Build next_state from the same keys
                    next_state_data: dict[str, torch.Tensor] = {}
                    for key in state_keys:
                        val = next_sample[key]
                        next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
                    next_state = next_state_data

            # ----- Construct the Transition -----
            transition = Transition(
                state=current_state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
            )
            transitions.append(transition)

        return transitions

    def sample(self, batch_size: int) -> BatchTransition:
        """Sample a random batch of transitions and collate them into batched tensors."""
        list_of_transitions = random.sample(self.memory, batch_size)

        # -- Build batched states --
        batch_state = {}
        for key in self.state_keys:
            batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
                self.device
            )

        # -- Build batched actions --
        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)

        # -- Build batched rewards --
        batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
            self.device
        )

        # -- Build batched next states --
        batch_next_state = {}
        for key in self.state_keys:
            batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
                self.device
            )

        # -- Build batched dones --
        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
            self.device
        )

        # Return a BatchTransition typed dict
        return BatchTransition(
            state=batch_state,
            action=batch_actions,
            reward=batch_rewards,
            next_state=batch_next_state,
            done=batch_dones,
        )
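
# Minimal usage sketch for `ReplayBuffer` (hypothetical keys and shapes, matching the API above):
#
#   buffer = ReplayBuffer(capacity=10_000, device="cpu", state_keys=["observation.state"])
#   buffer.add(
#       state={"observation.state": torch.zeros(1, 7)},
#       action=torch.zeros(1, 4),
#       reward=0.0,
#       next_state={"observation.state": torch.zeros(1, 7)},
#       done=False,
#   )
#   batch = buffer.sample(batch_size=1)  # BatchTransition with tensors stacked along dim 0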


def concatenate_batch_transitions(
    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
    """Concatenate two BatchTransitions. Note: this modifies `left_batch_transitions` in place."""
    left_batch_transitions["state"] = {
        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
        for key in left_batch_transitions["state"]
    }
    left_batch_transitions["action"] = torch.cat(
        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
    )
    left_batch_transitions["reward"] = torch.cat(
        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
    )
    left_batch_transitions["next_state"] = {
        key: torch.cat(
            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
        )
        for key in left_batch_transitions["next_state"]
    }
    left_batch_transitions["done"] = torch.cat(
        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
    )
    return left_batch_transitions
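
# `concatenate_batch_transitions` is used below to mix online and offline data: two batches of
# size N each yield a single batch of size 2N, concatenated along the batch dimension. Keep in
# mind that the left batch is mutated in place.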


def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    if out_dir is None:
        raise NotImplementedError()
    if job_name is None:
        raise NotImplementedError()

    init_logging()
    logging.info(pformat(OmegaConf.to_container(cfg)))

    # Create an env dedicated to online episode collection from policy rollouts.
    # online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
    # NOTE: Off-policy algorithms are efficient enough to use a single environment.
    logging.info("make_env online")
    online_env = make_env(cfg, n_envs=1)

    if cfg.training.eval_freq > 0:
        logging.info("make_env eval")
        eval_env = make_env(cfg, n_envs=1)

    # TODO: Add a way to resume training

    # Log metrics to terminal and wandb
    logger = Logger(cfg, out_dir, wandb_job_name=job_name)

    set_global_seed(cfg.seed)

    # Check device is available
    device = get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("make_policy")
    # TODO: At some point we should only need to make a SAC policy here.
    policy: SACPolicy = make_policy(
        hydra_cfg=cfg,
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: if we do online training, we do not need dataset_stats
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
    assert isinstance(policy, nn.Module)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)

    # TODO: Handle resume

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    log_output_dir(out_dir)
    logging.info(f"{cfg.env.task=}")
    # TODO: Handle offline steps
    # logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
    logging.info(f"{cfg.training.online_steps=}")
    # logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
    # logging.info(f"{offline_dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    obs, info = online_env.reset()

    obs = preprocess_observation(obs)
    obs = {key: obs[key].to(device, non_blocking=True) for key in obs}

    replay_buffer = ReplayBuffer(
        capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
    )

    batch_size = cfg.training.batch_size
    # if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
    #     raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
    if cfg.dataset_repo_id is not None:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
        logging.info("Conversion to an offline replay buffer")
        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
        )
        batch_size: int = batch_size // 2  # We will sample from both replay buffers

    # NOTE: For the moment we solely handle the case of a single environment.
    sum_reward_episode = 0

    for interaction_step in range(cfg.training.online_steps):
        # NOTE: At some point we should use a wrapper to handle the observation.

        if interaction_step >= cfg.training.online_step_before_learning:
            action = policy.select_action(batch=obs)
            next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
        else:
            action = online_env.action_space.sample()
            next_obs, reward, done, truncated, info = online_env.step(action)
            # HACK: wrap the randomly sampled numpy action into a tensor so it can be stored in the buffer.
            action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)

        next_obs = preprocess_observation(next_obs)
        next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
        sum_reward_episode += float(reward[0])
        # Because we are using a single environment
        # we can safely assume that the episode is done
        if done[0] or truncated[0]:
            logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
            logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
            sum_reward_episode = 0
            if "final_info" in info:
                if "is_success" in info["final_info"][0]:
                    logging.info(
                        f"Global step {interaction_step}: Episode success: {info['final_info'][0]['is_success']}"
                    )
                if "coverage" in info["final_info"][0]:
                    logging.info(
                        f"Global step {interaction_step}: Episode final coverage: {info['final_info'][0]['coverage']} \n"
                    )
                    logger.log_dict({"Final coverage": info["final_info"][0]["coverage"]}, interaction_step)

        replay_buffer.add(
            state=obs,
            action=action,
            reward=float(reward[0]),
            next_state=next_obs,
            done=done[0],
        )
        obs = next_obs

        if interaction_step >= cfg.training.online_step_before_learning:
            # Update-to-data (UTD) ratio: run `utd_ratio - 1` critic-only updates here, then one more
            # update below that also covers the actor, the temperature and logging.
            for _ in range(cfg.policy.utd_ratio - 1):
                batch = replay_buffer.sample(batch_size)
                if cfg.dataset_repo_id is not None:
                    batch_offline = offline_replay_buffer.sample(batch_size)
                    batch = concatenate_batch_transitions(batch, batch_offline)

                actions = batch["action"]
                rewards = batch["reward"]
                observations = batch["state"]
                next_observations = batch["next_state"]
                done = batch["done"]

                loss_critic = policy.compute_loss_critic(
                    observations=observations,
                    actions=actions,
                    rewards=rewards,
                    next_observations=next_observations,
                    done=done,
                )
                optimizers["critic"].zero_grad()
                loss_critic.backward()
                optimizers["critic"].step()

            batch = replay_buffer.sample(batch_size)
            if cfg.dataset_repo_id is not None:
                batch_offline = offline_replay_buffer.sample(batch_size)
                batch = concatenate_batch_transitions(
                    left_batch_transitions=batch, right_batch_transition=batch_offline
                )
            # NOTE: We have to handle the normalization for the batch
            # batch = policy.normalize_inputs(batch)

            actions = batch["action"]
            rewards = batch["reward"]
            observations = batch["state"]
            next_observations = batch["next_state"]
            done = batch["done"]

            loss_critic = policy.compute_loss_critic(
                observations=observations,
                actions=actions,
                rewards=rewards,
                next_observations=next_observations,
                done=done,
            )
            optimizers["critic"].zero_grad()
            loss_critic.backward()
            optimizers["critic"].step()

            training_infos = {}
            training_infos["loss_critic"] = loss_critic.item()

            if interaction_step % cfg.training.policy_update_freq == 0:
                # TD3 trick: delayed policy updates (the actor is updated less often than the critic).
                for _ in range(cfg.training.policy_update_freq):
                    loss_actor = policy.compute_loss_actor(observations=observations)

                    optimizers["actor"].zero_grad()
                    loss_actor.backward()
                    optimizers["actor"].step()

                    training_infos["loss_actor"] = loss_actor.item()

                    loss_temperature = policy.compute_loss_temperature(observations=observations)
                    optimizers["temperature"].zero_grad()
                    loss_temperature.backward()
                    optimizers["temperature"].step()

                    training_infos["loss_temperature"] = loss_temperature.item()

            if interaction_step % cfg.training.log_freq == 0:
                logger.log_dict(training_infos, interaction_step, mode="train")

            policy.update_target_networks()
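
# Two entry points call `train()`: `train_cli` is the Hydra-decorated command-line entry point
# (it derives `out_dir` and `job_name` from the Hydra run config), and `train_notebook` composes
# the config programmatically for interactive use.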


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
    train(
        cfg,
        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
    )


def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
    from hydra import compose, initialize

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=config_path)
    cfg = compose(config_name=config_name)
    train(cfg, out_dir=out_dir, job_name=job_name)
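
# Example (hypothetical arguments) of launching training from a notebook:
#
#   train_notebook(out_dir="outputs/train/sac_example", job_name="sac_example")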


if __name__ == "__main__":
    train_cli()