Add rlpd tricks
parent 0ffc0a7170
commit 278b56bce9
@@ -266,6 +266,7 @@ class SACPolicy(
         # critics subsample size
         min_q, _ = q_targets.min(dim=0)  # Get values from min operation
-        min_q = min_q - (temperature * next_log_probs)
+        if self.config.use_backup_entropy:
+            min_q = min_q - (temperature * next_log_probs)
 
         td_target = rewards + (1 - done) * self.config.discount * min_q
 
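
The hunk above gates the entropy term of the critic target behind `config.use_backup_entropy`, i.e. the RLPD-style option of dropping the entropy bonus from the TD backup. Below is a minimal standalone sketch of that target computation, assuming a clipped double-Q ensemble; the shapes and toy values are illustrative, and only the `use_backup_entropy` switch mirrors the diff.

    import torch

    def td_target(
        q_targets: torch.Tensor,       # (num_critics, batch) target-critic Q-values
        next_log_probs: torch.Tensor,  # (batch,) log pi(a'|s') of the sampled next actions
        rewards: torch.Tensor,         # (batch,)
        done: torch.Tensor,            # (batch,) 1.0 where the episode terminated
        temperature: float,
        discount: float,
        use_backup_entropy: bool,
    ) -> torch.Tensor:
        # Clipped double-Q: take the minimum over the critic ensemble.
        min_q, _ = q_targets.min(dim=0)
        # RLPD-style switch: optionally drop the entropy bonus from the backup.
        if use_backup_entropy:
            min_q = min_q - temperature * next_log_probs
        return rewards + (1 - done) * discount * min_q

    # Toy check: 2 critics, batch of 4, no terminal states.
    q = torch.randn(2, 4)
    print(td_target(q, torch.randn(4), torch.ones(4), torch.zeros(4), 0.1, 0.99, True))
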
@@ -30,9 +30,10 @@ from omegaconf import DictConfig, ListConfig, OmegaConf
 from termcolor import colored
 from torch import nn
 from torch.cuda.amp import GradScaler
+from tqdm import tqdm
 
 from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
-from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset, LeRobotDataset
 from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import cycle
@@ -64,7 +65,6 @@ def make_optimizers_and_scheduler(cfg, policy):
     # We wrap the policy log temperature in a list because it is a torch tensor, not an nn.Module
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
     lr_scheduler = None
-
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
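
The context comment above concerns `policy.log_alpha` being a raw tensor rather than an `nn.Module`, which is why Adam receives it wrapped in a list. A minimal sketch of that pattern with a stand-alone temperature parameter; the target entropy value and the loss form are the usual SAC ones, not copied from this repo.

    import torch

    # SAC temperature stored as a raw learnable tensor (not a parameter inside a module).
    log_alpha = torch.tensor(0.0, requires_grad=True)
    target_entropy = -4.0  # commonly -dim(action space); the value here is illustrative

    # Adam expects an iterable of parameters, hence the bare tensor is wrapped in a list.
    optimizer_temperature = torch.optim.Adam(params=[log_alpha], lr=3e-4)

    # One illustrative temperature update from the log-probs of a batch of sampled actions.
    log_probs = torch.randn(8)  # stand-in for log pi(a|s)
    loss_temperature = -(log_alpha.exp() * (log_probs + target_entropy).detach()).mean()

    optimizer_temperature.zero_grad()
    loss_temperature.backward()
    optimizer_temperature.step()
    print(float(log_alpha.exp()))  # current temperature alpha
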
@@ -136,6 +136,105 @@ class ReplayBuffer:
         )
         self.position = (self.position + 1) % self.capacity
 
+    @classmethod
+    def from_lerobot_dataset(
+        cls,
+        lerobot_dataset: LeRobotDataset,
+        device: str = "cuda:0",
+        state_keys: Optional[Sequence[str]] = None,
+    ) -> "ReplayBuffer":
+        replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
+        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
+        for data in list_transition:
+            replay_buffer.add(
+                state=data["state"],
+                action=data["action"],
+                reward=data["reward"],
+                next_state=data["next_state"],
+                done=data["done"],
+            )
+        return replay_buffer
+
+    @staticmethod
+    def _lerobotdataset_to_transitions(
+        dataset: LeRobotDataset,
+        state_keys: Optional[Sequence[str]] = None,
+    ) -> list[Transition]:
+        """
+        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
+
+        Args:
+            dataset (LeRobotDataset):
+                The dataset to convert. Each item in the dataset is expected to have
+                at least the following keys:
+                {
+                    "action": ...
+                    "next.reward": ...
+                    "next.done": ...
+                    "episode_index": ...
+                }
+                plus whatever your `state_keys` specify.
+
+            state_keys (Optional[Sequence[str]]):
+                The dataset keys to include in `state` and `next_state`. Their names
+                will be kept as-is in the output transitions, e.g.
+                ["observation.state", "observation.environment_state"].
+                If None, you must handle or define default keys.
+
+        Returns:
+            transitions (list[Transition]):
+                A list of Transition dictionaries with the same length as `dataset`.
+        """
+
+        # If not provided, you can either raise an error or define a default:
+        if state_keys is None:
+            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
+
+        transitions: list[Transition] = []
+        num_frames = len(dataset)
+
+        for i in tqdm(range(num_frames)):
+            current_sample = dataset[i]
+
+            # ----- 1) Current state -----
+            current_state: dict[str, torch.Tensor] = {}
+            for key in state_keys:
+                val = current_sample[key]
+                current_state[key] = val.unsqueeze(0)  # Add batch dimension
+
+            # ----- 2) Action -----
+            action = current_sample["action"].unsqueeze(0)  # Add batch dimension
+
+            # ----- 3) Reward and done -----
+            reward = float(current_sample["next.reward"].item())  # ensure float
+            done = bool(current_sample["next.done"].item())  # ensure bool
+
+            # ----- 4) Next state -----
+            # If not done and the next sample is in the same episode, we pull the next sample's state.
+            # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
+            next_state = current_state  # default
+            if not done and (i < num_frames - 1):
+                next_sample = dataset[i + 1]
+                if next_sample["episode_index"] == current_sample["episode_index"]:
+                    # Build next_state from the same keys
+                    next_state_data: dict[str, torch.Tensor] = {}
+                    for key in state_keys:
+                        val = next_sample[key]
+                        next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
+                    next_state = next_state_data
+
+            # ----- Construct the Transition -----
+            transition = Transition(
+                state=current_state,
+                action=action,
+                reward=reward,
+                next_state=next_state,
+                done=done,
+            )
+            transitions.append(transition)
+
+        return transitions
+
     def sample(self, batch_size: int) -> BatchTransition:
         """Sample a random batch of transitions and collate them into batched tensors."""
         list_of_transitions = random.sample(self.memory, batch_size)
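
The heart of `_lerobotdataset_to_transitions` is the episode-boundary rule: the next state is taken from frame `i + 1` only when the current frame is not terminal and frame `i + 1` belongs to the same episode; otherwise the current state is carried over. A toy, self-contained reproduction of that rule with hand-made frames (the tensors and episode layout are invented for illustration):

    import torch

    # Hand-made frames shaped like dataset items: one 2-frame episode, then a 1-frame episode.
    frames = [
        {"observation.state": torch.tensor([0.0]), "next.done": torch.tensor(False), "episode_index": torch.tensor(0)},
        {"observation.state": torch.tensor([1.0]), "next.done": torch.tensor(True), "episode_index": torch.tensor(0)},
        {"observation.state": torch.tensor([2.0]), "next.done": torch.tensor(False), "episode_index": torch.tensor(1)},
    ]

    # Same rule as in the converter: frame i+1 provides the next state only when the
    # current frame is not terminal and i+1 is still in the same episode.
    for i, frame in enumerate(frames):
        done = bool(frame["next.done"].item())
        next_state = frame["observation.state"]  # fallback: carry the current state over
        if not done and i < len(frames) - 1 and frames[i + 1]["episode_index"] == frame["episode_index"]:
            next_state = frames[i + 1]["observation.state"]
        print(i, float(frame["observation.state"]), "->", float(next_state))
    # Frame 0 points to frame 1's state; frames 1 (terminal) and 2 (end of data) point to themselves.
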
@@ -177,6 +276,32 @@ class ReplayBuffer:
         )
 
 
+def concatenate_batch_transitions(
+    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
+) -> BatchTransition:
+    """Be careful: this modifies `left_batch_transitions` in place."""
+    left_batch_transitions["state"] = {
+        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
+        for key in left_batch_transitions["state"]
+    }
+    left_batch_transitions["action"] = torch.cat(
+        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
+    )
+    left_batch_transitions["reward"] = torch.cat(
+        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
+    )
+    left_batch_transitions["next_state"] = {
+        key: torch.cat(
+            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
+        )
+        for key in left_batch_transitions["next_state"]
+    }
+    left_batch_transitions["done"] = torch.cat(
+        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
+    )
+    return left_batch_transitions
+
+
 def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     if out_dir is None:
         raise NotImplementedError()
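
Because `concatenate_batch_transitions` writes into its first argument, the merged batch and the left input are the same object. A small sketch of its behaviour on toy tensors, assuming the function from the hunk above is in scope and that `BatchTransition` is a plain dict of tensors keyed as in this file:

    import torch

    # Two toy BatchTransition-like dicts with matching keys (batch sizes 2 and 3).
    online_batch = {
        "state": {"observation.state": torch.zeros(2, 4)},
        "action": torch.zeros(2, 2),
        "reward": torch.zeros(2),
        "next_state": {"observation.state": torch.zeros(2, 4)},
        "done": torch.zeros(2, dtype=torch.bool),
    }
    offline_batch = {
        "state": {"observation.state": torch.ones(3, 4)},
        "action": torch.ones(3, 2),
        "reward": torch.ones(3),
        "next_state": {"observation.state": torch.ones(3, 4)},
        "done": torch.zeros(3, dtype=torch.bool),
    }

    merged = concatenate_batch_transitions(online_batch, offline_batch)
    assert merged is online_batch            # the left batch was modified in place
    assert merged["action"].shape[0] == 5    # 2 online + 3 offline transitions
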
@@ -186,9 +311,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     init_logging()
     logging.info(pformat(OmegaConf.to_container(cfg)))
 
-    if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
-        raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
-
     # Create an env dedicated to online episodes collection from policy rollout.
     # online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
     # NOTE: Off-policy algorithms are efficient enough to use a single environment
@@ -250,6 +372,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
     replay_buffer = ReplayBuffer(
         capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
     )
 
+    breakpoint()
+    batch_size = cfg.training.batch_size
+
+    # if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
+    #     raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
+    if cfg.dataset_repo_id is not None:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+        logging.info("Conversion to an offline replay buffer")
+        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+        )
+        batch_size: int = batch_size // 2  # We will sample from both replay buffers
+
     # NOTE: For the moment we will solely handle the case of a single environment
     sum_reward_episode = 0
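
Halving `batch_size` when an offline dataset is configured implements RLPD-style symmetric sampling: each gradient step sees half fresh interaction data and half prior demonstrations, so the effective batch size stays at `cfg.training.batch_size`. A sketch of that split, assuming both buffers expose the `sample()` interface defined above (the function name here is illustrative):

    def sample_mixed_batch(replay_buffer, offline_replay_buffer, total_batch_size):
        half = total_batch_size // 2
        online_part = replay_buffer.sample(half)            # fresh interaction data
        offline_part = offline_replay_buffer.sample(half)   # prior demonstration data
        # In-place concat: the result holds 2 * half transitions, a 50/50 mix.
        return concatenate_batch_transitions(online_part, offline_part)
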
@@ -285,7 +421,33 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
         obs = next_obs
 
         if interaction_step >= cfg.training.online_step_before_learning:
-            batch = replay_buffer.sample(cfg.training.batch_size)
+            for _ in range(cfg.policy.utd_ratio - 1):
+                batch = replay_buffer.sample(batch_size)
+                if cfg.dataset_repo_id is not None:
+                    batch_offline = offline_replay_buffer.sample(batch_size)
+                    batch = concatenate_batch_transitions(batch, batch_offline)
+
+                actions = batch["action"]
+                rewards = batch["reward"]
+                observations = batch["state"]
+                next_observations = batch["next_state"]
+                done = batch["done"]
+
+                loss_critic = policy.compute_loss_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+                optimizers["critic"].zero_grad()
+                loss_critic.backward()
+                optimizers["critic"].step()
+
+            batch = replay_buffer.sample(batch_size)
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)
             # 'observation.state', 'action', 'next.reward', 'next.done'
             # TODO: (azouitine) interface to refine
             # TODO: At some point we should find a way to normalize the inputs
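
The rewritten learning block applies an update-to-data (UTD) ratio, another RLPD trick: `utd_ratio - 1` critic-only updates run on fresh mixed batches, and the final batch then feeds the full update path (critic, actor, temperature) in the unchanged remainder of the loop. A standalone sketch of that control flow with stubbed update functions; every name here is a placeholder, not the repo's API.

    def learner_step(utd_ratio, sample_batch, update_critic, update_actor_and_temperature):
        """Run `utd_ratio` critic updates per environment step, actor/temperature once."""
        # Critic-only updates on fresh batches (all but the last one).
        for _ in range(utd_ratio - 1):
            update_critic(sample_batch())
        # Final batch: critic update plus the actor and temperature updates.
        batch = sample_batch()
        update_critic(batch)
        update_actor_and_temperature(batch)

    # Toy run with counters standing in for the real optimizers.
    counts = {"critic": 0, "actor": 0}
    learner_step(
        utd_ratio=4,
        sample_batch=lambda: None,
        update_critic=lambda batch: counts.__setitem__("critic", counts["critic"] + 1),
        update_actor_and_temperature=lambda batch: counts.__setitem__("actor", counts["actor"] + 1),
    )
    print(counts)  # {'critic': 4, 'actor': 1}
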