Add RLPD tricks

Adil Zouitine 2025-01-15 15:49:24 +01:00
parent 0ffc0a7170
commit 278b56bce9
2 changed files with 170 additions and 7 deletions


@@ -266,7 +266,8 @@ class SACPolicy(
# critics subsample size
min_q, _ = q_targets.min(dim=0)  # Get values from min operation
if self.config.use_backup_entropy:
    min_q = min_q - (temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
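For context on this change: `use_backup_entropy` toggles the RLPD trick of dropping the entropy term from the critic's bootstrap target. A minimal standalone sketch of the resulting target computation, with made-up tensor shapes (only the variable names from the diff are real):

import torch

# Illustrative shapes: an ensemble of 2 target critics, batch of 4 transitions.
q_targets = torch.randn(2, 4)      # Q-values from the target critic ensemble
next_log_probs = torch.randn(4)    # log pi(a' | s') for sampled next actions
rewards = torch.randn(4)
done = torch.zeros(4)              # 1.0 where the episode terminated
temperature, discount = 0.2, 0.99
use_backup_entropy = True          # stands in for self.config.use_backup_entropy

min_q, _ = q_targets.min(dim=0)    # clipped double-Q: elementwise min over the ensemble
if use_backup_entropy:
    # Soft backup: subtract the entropy bonus from the bootstrapped value.
    min_q = min_q - temperature * next_log_probs
td_target = rewards + (1 - done) * discount * min_q
print(td_target.shape)             # torch.Size([4])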


@@ -30,9 +30,10 @@ from omegaconf import DictConfig, ListConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler
from tqdm import tqdm

from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset, LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
@@ -64,7 +65,6 @@ def make_optimizers_and_scheduler(cfg, policy):
# We wrap the policy's log temperature in a list because it is a torch tensor, not an nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
    "actor": optimizer_actor,
    "critic": optimizer_critic,
@@ -136,6 +136,105 @@ class ReplayBuffer:
)
self.position = (self.position + 1) % self.capacity
@classmethod
def from_lerobot_dataset(
cls,
lerobot_dataset: LeRobotDataset,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
) -> "ReplayBuffer":
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
for data in list_transition:
replay_buffer.add(
state=data["state"],
action=data["action"],
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
)
return replay_buffer
@staticmethod
def _lerobotdataset_to_transitions(
dataset: LeRobotDataset,
state_keys: Optional[Sequence[str]] = None,
) -> list[Transition]:
"""
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
Args:
dataset (LeRobotDataset):
The dataset to convert. Each item in the dataset is expected to have
at least the following keys:
{
"action": ...
"next.reward": ...
"next.done": ...
"episode_index": ...
}
plus whatever your 'state_keys' specify.
state_keys (Optional[Sequence[str]]):
The dataset keys to include in 'state' and 'next_state'. Their names
will be kept as-is in the output transitions. E.g.
["observation.state", "observation.environment_state"].
If None, you must handle or define default keys.
Returns:
transitions (List[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
transitions: list[Transition] = []
num_frames = len(dataset)
for i in tqdm(range(num_frames)):
current_sample = dataset[i]
# ----- 1) Current state -----
current_state: dict[str, torch.Tensor] = {}
for key in state_keys:
val = current_sample[key]
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
action = current_sample["action"].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
done = bool(current_sample["next.done"].item()) # ensure bool
# ----- 4) Next state -----
# If not done and the next sample is in the same episode, we pull the next sample's state.
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
next_state = current_state # default
if not done and (i < num_frames - 1):
next_sample = dataset[i + 1]
if next_sample["episode_index"] == current_sample["episode_index"]:
# Build next_state from the same keys
next_state_data: dict[str, torch.Tensor] = {}
for key in state_keys:
val = next_sample[key]
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
next_state = next_state_data
# ----- Construct the Transition -----
transition = Transition(
state=current_state,
action=action,
reward=reward,
next_state=next_state,
done=done,
)
transitions.append(transition)
return transitions
def sample(self, batch_size: int) -> BatchTransition:
    """Sample a random batch of transitions and collate them into batched tensors."""
    list_of_transitions = random.sample(self.memory, batch_size)
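A hedged usage sketch of the new offline path above: it assumes this script's `ReplayBuffer` is in scope and that the chosen dataset exposes `action`, `next.reward`, `next.done`, `episode_index`, and the listed state keys. The repo id and state keys below are placeholders:

import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Placeholder dataset and state keys; substitute your own repo id / policy inputs.
dataset = LeRobotDataset("lerobot/pusht")
state_keys = ["observation.state", "observation.image"]
device = "cuda:0" if torch.cuda.is_available() else "cpu"

offline_buffer = ReplayBuffer.from_lerobot_dataset(dataset, device=device, state_keys=state_keys)
batch = offline_buffer.sample(batch_size=32)
print(batch["action"].shape, batch["reward"].shape)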
@@ -177,6 +276,32 @@ class ReplayBuffer:
)
def concatenate_batch_transitions(
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
"""Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
)
for key in left_batch_transitions["next_state"]
}
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
return left_batch_transitions
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    if out_dir is None:
        raise NotImplementedError()
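A small illustration of the in-place behavior of `concatenate_batch_transitions` above, using toy tensors laid out like a `BatchTransition` (shapes and values are made up):

import torch

online = {
    "state": {"observation.state": torch.zeros(4, 6)},
    "action": torch.zeros(4, 2),
    "reward": torch.zeros(4),
    "next_state": {"observation.state": torch.zeros(4, 6)},
    "done": torch.zeros(4),
}
offline = {
    "state": {"observation.state": torch.ones(4, 6)},
    "action": torch.ones(4, 2),
    "reward": torch.ones(4),
    "next_state": {"observation.state": torch.ones(4, 6)},
    "done": torch.ones(4),
}

merged = concatenate_batch_transitions(online, offline)
print(merged["action"].shape)  # torch.Size([8, 2])
print(merged is online)        # True: the left batch is modified in place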
@@ -186,9 +311,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
# Create an env dedicated to online episode collection from policy rollouts.
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
# NOTE: Off-policy algorithms are efficient enough to use a single environment
@@ -250,6 +372,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
replay_buffer = ReplayBuffer(
    capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
)
batch_size = cfg.training.batch_size
# if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
# raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
if cfg.dataset_repo_id is not None:
    logging.info("Creating the offline dataset with make_dataset")
    offline_dataset = make_dataset(cfg)
    logging.info("Converting the offline dataset into an offline replay buffer")
    offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
        offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
    )
    batch_size: int = batch_size // 2  # We will sample from both replay buffers
# NOTE: For the moment we only handle the case of a single environment
sum_reward_episode = 0
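Halving the batch size implements RLPD-style symmetric sampling: each gradient step draws half of its batch from the online buffer and half from the offline buffer. A self-contained sketch using the `ReplayBuffer` and `concatenate_batch_transitions` defined above (the state key, shapes, and sizes are illustrative):

import torch

state_keys = ["observation.state"]
online_buffer = ReplayBuffer(capacity=1000, device="cpu", state_keys=state_keys)
offline_buffer = ReplayBuffer(capacity=1000, device="cpu", state_keys=state_keys)
for buffer in (online_buffer, offline_buffer):
    for _ in range(64):
        obs = {"observation.state": torch.randn(1, 6)}
        buffer.add(state=obs, action=torch.randn(1, 2), reward=0.0, next_state=obs, done=False)

# 50/50 split between online and offline data.
batch_size = 32
batch = concatenate_batch_transitions(
    online_buffer.sample(batch_size // 2), offline_buffer.sample(batch_size // 2)
)
print(batch["action"].shape)  # expected: torch.Size([32, 2]) with the toy shapes above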
@@ -285,7 +421,33 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
obs = next_obs

if interaction_step >= cfg.training.online_step_before_learning:
    for _ in range(cfg.policy.utd_ratio - 1):
        batch = replay_buffer.sample(batch_size)
        if cfg.dataset_repo_id is not None:
            batch_offline = offline_replay_buffer.sample(batch_size)
            batch = concatenate_batch_transitions(batch, batch_offline)

        actions = batch["action"]
        rewards = batch["reward"]
        observations = batch["state"]
        next_observations = batch["next_state"]
        done = batch["done"]

        loss_critic = policy.compute_loss_critic(
            observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=next_observations,
            done=done,
        )
        optimizers["critic"].zero_grad()
        loss_critic.backward()
        optimizers["critic"].step()

    batch = replay_buffer.sample(batch_size)
    if cfg.dataset_repo_id is not None:
        batch_offline = offline_replay_buffer.sample(batch_size)
        batch = concatenate_batch_transitions(batch, batch_offline)
# 'observation.state', 'action', 'next.reward', 'next.done'
# TODO: (azouitine) interface to refine
# TODO: At some point we should find a way to normalize the inputs
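For reference, the update-to-data (UTD) trick above runs `cfg.policy.utd_ratio` gradient updates per environment step: the first `utd_ratio - 1` batches train only the critic, and the last sampled batch presumably feeds the full update that follows this hunk. A schematic of that control flow with hypothetical callables (none of these helpers exist in the script):

from typing import Callable

def utd_gradient_steps(
    utd_ratio: int,
    sample_batch: Callable[[], dict],
    critic_update: Callable[[dict], None],
    full_update: Callable[[dict], None],
) -> None:
    # The first utd_ratio - 1 batches train only the critic ...
    for _ in range(utd_ratio - 1):
        critic_update(sample_batch())
    # ... and the final batch feeds the full update (critic + actor + temperature).
    full_update(sample_batch())

# Toy usage with no-op callables, just to show the call pattern.
utd_gradient_steps(
    utd_ratio=2,
    sample_batch=lambda: {},
    critic_update=lambda batch: None,
    full_update=lambda batch: None,
)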