From 921ed960fb174d3a7e4ca753ed71a26d1b830ac2 Mon Sep 17 00:00:00 2001
From: Adil Zouitine
Date: Wed, 15 Jan 2025 15:49:24 +0100
Subject: [PATCH] Add RLPD tricks

---
 lerobot/common/policies/sac/modeling_sac.py |   3 +-
 lerobot/scripts/train_sac.py                | 173 ++++++++++++++++++-
 2 files changed, 169 insertions(+), 7 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index f2d10ae5..e5173e04 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -266,7 +266,8 @@ class SACPolicy(
         # critics subsample size
         min_q, _ = q_targets.min(dim=0)  # Get values from min operation
-        min_q = min_q - (temperature * next_log_probs)
+        if self.config.use_backup_entropy:
+            min_q = min_q - (temperature * next_log_probs)
 
         td_target = rewards + (1 - done) * self.config.discount * min_q
 
diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py
index 30891db9..942a19ab 100644
--- a/lerobot/scripts/train_sac.py
+++ b/lerobot/scripts/train_sac.py
@@ -30,9 +30,10 @@
 from omegaconf import DictConfig, ListConfig, OmegaConf
 from termcolor import colored
 from torch import nn
 from torch.cuda.amp import GradScaler
+from tqdm import tqdm
 from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
-from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
 from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import cycle
@@ -64,7 +65,6 @@ def make_optimizers_and_scheduler(cfg, policy):
     # We wrap the policy log temperature in a list because it is a torch tensor, not an nn.Module
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
     lr_scheduler = None
-
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
@@ -136,6 +136,105 @@ class ReplayBuffer:
         )
         self.position = (self.position + 1) % self.capacity
 
+    @classmethod
+    def from_lerobot_dataset(
+        cls,
+        lerobot_dataset: LeRobotDataset,
+        device: str = "cuda:0",
+        state_keys: Optional[Sequence[str]] = None,
+    ) -> "ReplayBuffer":
+        replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
+        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
+        for data in list_transition:
+            replay_buffer.add(
+                state=data["state"],
+                action=data["action"],
+                reward=data["reward"],
+                next_state=data["next_state"],
+                done=data["done"],
+            )
+        return replay_buffer
+
+    @staticmethod
+    def _lerobotdataset_to_transitions(
+        dataset: LeRobotDataset,
+        state_keys: Optional[Sequence[str]] = None,
+    ) -> list[Transition]:
+        """
+        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
+
+        Args:
+            dataset (LeRobotDataset):
+                The dataset to convert. Each item in the dataset is expected to have
+                at least the following keys:
+                {
+                    "action": ...
+                    "next.reward": ...
+                    "next.done": ...
+                    "episode_index": ...
+                }
+                plus whatever your `state_keys` specify.
+
+            state_keys (Optional[Sequence[str]]):
+                The dataset keys to include in 'state' and 'next_state'. Their names
+                will be kept as-is in the output transitions. E.g.
+                ["observation.state", "observation.environment_state"].
+                If None, you must handle or define default keys.
+
+        Returns:
+            transitions (list[Transition]):
+                A list of Transition dictionaries with the same length as `dataset`.
+        """
+
+        # If not provided, you can either raise an error or define a default:
+        if state_keys is None:
+            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
+
+        transitions: list[Transition] = []
+        num_frames = len(dataset)
+
+        for i in tqdm(range(num_frames)):
+            current_sample = dataset[i]
+
+            # ----- 1) Current state -----
+            current_state: dict[str, torch.Tensor] = {}
+            for key in state_keys:
+                val = current_sample[key]
+                current_state[key] = val.unsqueeze(0)  # Add batch dimension
+
+            # ----- 2) Action -----
+            action = current_sample["action"].unsqueeze(0)  # Add batch dimension
+
+            # ----- 3) Reward and done -----
+            reward = float(current_sample["next.reward"].item())  # ensure float
+            done = bool(current_sample["next.done"].item())  # ensure bool
+
+            # ----- 4) Next state -----
+            # If not done and the next sample is in the same episode, we pull the next sample's state.
+            # Otherwise (done=True or the next sample crosses into a new episode), next_state = current_state.
+            next_state = current_state  # default
+            if not done and (i < num_frames - 1):
+                next_sample = dataset[i + 1]
+                if next_sample["episode_index"] == current_sample["episode_index"]:
+                    # Build next_state from the same keys
+                    next_state_data: dict[str, torch.Tensor] = {}
+                    for key in state_keys:
+                        val = next_sample[key]
+                        next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
+                    next_state = next_state_data
+
+            # ----- Construct the Transition -----
+            transition = Transition(
+                state=current_state,
+                action=action,
+                reward=reward,
+                next_state=next_state,
+                done=done,
+            )
+            transitions.append(transition)
+
+        return transitions
+
     def sample(self, batch_size: int) -> BatchTransition:
         """Sample a random batch of transitions and collate them into batched tensors."""
         list_of_transitions = random.sample(self.memory, batch_size)
@@ -177,6 +276,32 @@ class ReplayBuffer:
         )
 
 
+def concatenate_batch_transitions(
+    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
+) -> BatchTransition:
+    """Concatenate two batches of transitions. Note: this modifies `left_batch_transitions` in place."""
+    left_batch_transitions["state"] = {
+        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
+        for key in left_batch_transitions["state"]
+    }
+    left_batch_transitions["action"] = torch.cat(
+        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
+    )
+    left_batch_transitions["reward"] = torch.cat(
+        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
+    )
+    left_batch_transitions["next_state"] = {
+        key: torch.cat(
+            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
+        )
+        for key in left_batch_transitions["next_state"]
+    }
+    left_batch_transitions["done"] = torch.cat(
+        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
+    )
+    return left_batch_transitions
+
+
 def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
@@ -186,9 +311,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     init_logging()
     logging.info(pformat(OmegaConf.to_container(cfg)))
 
-    if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
-        raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
-
     # Create an env dedicated to online episodes collection from policy rollout.
     # online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
     # NOTE: Off-policy algorithms are sample-efficient enough to use a single environment
@@ -250,6 +372,19 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     replay_buffer = ReplayBuffer(
         capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
     )
+
+    batch_size = cfg.training.batch_size
+    # if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
+    #     raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
+    if cfg.dataset_repo_id is not None:
+        logging.info("Making offline dataset")
+        offline_dataset = make_dataset(cfg)
+        logging.info("Converting the offline dataset into an offline replay buffer")
+        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+        )
+        batch_size: int = batch_size // 2  # We will sample from both replay buffers
+
     # NOTE: For the moment we only handle the single-environment case
     sum_reward_episode = 0
@@ -285,7 +420,33 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         obs = next_obs
 
         if interaction_step >= cfg.training.online_step_before_learning:
-            batch = replay_buffer.sample(cfg.training.batch_size)
+            for _ in range(cfg.policy.utd_ratio - 1):
+                batch = replay_buffer.sample(batch_size)
+                if cfg.dataset_repo_id is not None:
+                    batch_offline = offline_replay_buffer.sample(batch_size)
+                    batch = concatenate_batch_transitions(batch, batch_offline)
+
+                actions = batch["action"]
+                rewards = batch["reward"]
+                observations = batch["state"]
+                next_observations = batch["next_state"]
+                done = batch["done"]
+
+                loss_critic = policy.compute_loss_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+                optimizers["critic"].zero_grad()
+                loss_critic.backward()
+                optimizers["critic"].step()
+
+            batch = replay_buffer.sample(batch_size)
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)
             # 'observation.state', 'action', 'next.reward', 'next.done'
             # TODO: (azouitine) interface to refine
             # TODO: At some point we should find a way to normalize the inputs
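
Editor's note: the modeling_sac.py hunk makes the entropy ("backup entropy") term in the critic's TD target optional, one of the RLPD tricks. Below is a minimal scalar sketch of the effect, assuming `use_backup_entropy` toggles the standard SAC soft backup; all values are illustrative stand-ins, not taken from the patch.

    # Scalar stand-ins for the tensors used in compute_loss_critic (illustrative only).
    discount = 0.99        # discount factor
    temperature = 0.2      # SAC entropy temperature (alpha)
    reward, done = 1.0, 0.0
    min_q = 5.0            # min over the target critic ensemble
    next_log_probs = -1.3  # log pi(a'|s') for the sampled next action

    for use_backup_entropy in (True, False):
        target_v = min_q
        if use_backup_entropy:
            # Soft backup (SAC default): keep the entropy bonus of the next action.
            target_v = target_v - temperature * next_log_probs
        td_target = reward + (1 - done) * discount * target_v
        print(f"use_backup_entropy={use_backup_entropy}: td_target={td_target:.3f}")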
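
Editor's note: the train_sac.py changes combine two more RLPD ingredients: symmetric sampling (each training batch is half online, half offline transitions, hence `batch_size // 2`) and a high update-to-data ratio (`cfg.policy.utd_ratio - 1` critic-only updates per environment step before the final full update). A self-contained toy sketch of that control flow follows; `ToyBuffer`, `utd_ratio`, and `batch_size` are hypothetical stand-ins for the real replay buffers and config values.

    import random

    class ToyBuffer:
        """Illustrative stand-in for the ReplayBuffer above (holds ints, not tensors)."""
        def __init__(self, items):
            self.items = list(items)

        def sample(self, n):
            # Sample with replacement, like a replay buffer drawing random transitions.
            return random.choices(self.items, k=n)

    online_buffer = ToyBuffer(range(100))          # filled by policy rollouts
    offline_buffer = ToyBuffer(range(1000, 1100))  # filled via from_lerobot_dataset

    utd_ratio = 4          # hypothetical cfg.policy.utd_ratio
    batch_size = 256 // 2  # halved: online + offline halves rebuild the full batch

    for _ in range(utd_ratio - 1):
        # Critic-only updates on a 50/50 mix of online and offline transitions.
        batch = online_buffer.sample(batch_size) + offline_buffer.sample(batch_size)
        # ... critic update on `batch` would go here ...

    # The last update of the step reuses the same mixed sampling and, in the
    # real script, also updates the actor and the temperature.
    batch = online_buffer.sample(batch_size) + offline_buffer.sample(batch_size)
    print(len(batch))  # 256: the full gradient batch size is preserved

Halving the per-buffer batch keeps the effective gradient batch size constant while anchoring every update on demonstration data, which is the core of RLPD's sample-efficiency gains.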