[WIP] correct sac implementation

This commit is contained in:
Adil Zouitine 2025-01-13 17:54:11 +01:00
parent 380b836eee
commit 73aa6c25f3
2 changed files with 1583 additions and 0 deletions

View File

@ -0,0 +1,592 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: (1) better device management
from collections import deque
from typing import Callable, Optional, Sequence, Tuple, Union
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
class SACPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
):
name = "sac"
def __init__(
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
super().__init__()
if config is None:
config = SACConfig()
self.config = config
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
# HACK: we need to pass the dataset_stats to the normalization functions
dataset_stats = dataset_stats or {
"action": {
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
}
}
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
encoder_critic = SACObservationEncoder(config)
encoder_actor = SACObservationEncoder(config)
# Define networks
critic_nets = []
for _ in range(config.num_critics):
critic_net = Critic(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
),
)
critic_nets.append(critic_net)
target_critic_nets = []
for _ in range(config.num_critics):
target_critic_net = Critic(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
),
)
target_critic_nets.append(target_critic_net)
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO: fix later device
# TODO: Handle the case where the temparameter is a fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
self.temperature = self.log_alpha.exp().item()
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
queues are populated during rollout of the policy, they contain the n latest observations and actions
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=1),
}
if "observation.image" in self.config.input_shapes:
self._queues["observation.image"] = deque(maxlen=1)
if "observation.environment_state" in self.config.input_shapes:
self._queues["observation.environment_state"] = deque(maxlen=1)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
actions, _, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = torch.stack([critic(observations, actions) for critic in critics])
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
# We have to actualize the value of the temperature because in the previous
self.temperature = self.log_alpha.exp().item()
temperature = self.temperature
batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...) where index 1 returns the current observation and
# the next observation for calculating the right td index.
# actions = batch["action"][:, 0]
actions = batch["action"]
rewards = batch["next.reward"][:, 0]
observations = {}
next_observations = {}
for k in batch:
if k.startswith("observation."):
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]
done = batch["next.done"]
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q -= self.temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~done
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
actions_pi, log_probs, _ = self.actor(observations)
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
loss = critics_loss + actor_loss + temperature_loss
return {
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts": min_q_preds.min().item(),
"max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature,
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": log_probs.mean().item(),
"loss": loss,
}
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q -= temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~done
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
return critics_loss
def compute_loss_temperature(self, observations) -> Tensor:
breakpoint()
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(self, observations) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations)
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
return actor_loss
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
):
super().__init__()
self.activate_final = activate_final
layers = []
# First layer uses input_dim
layers.append(nn.Linear(input_dim, hidden_dims[0]))
# Add activation after first layer
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
# Rest of the layers
for i in range(1, len(hidden_dims)):
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class Critic(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.init_final = init_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Output layer
if init_final is not None:
self.output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
self.output_layer = nn.Linear(out_features, 1)
orthogonal_init()(self.output_layer.weight)
self.to(self.device)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
) -> torch.Tensor:
# Move each tensor in observations to device
observations = {k: v.to(self.device) for k, v in observations.items()}
actions = actions.to(self.device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
inputs = torch.cat([obs_enc, actions], dim=-1)
x = self.network(inputs)
value = self.output_layer(x)
return value.squeeze(-1)
class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
action_dim: int,
log_std_min: float = -5,
log_std_max: float = 2,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.use_tanh_squash = use_tanh_squash
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.mean_layer.weight)
# Standard deviation layer or parameter
if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observations if self.encoder is None else self.encoder(observations)
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
log_std = self.fixed_std.expand_as(means)
# uses tanh activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
log_probs = log_probs.sum(-1) # Sum over action dimensions
means = torch.tanh(means) if self.use_tanh_squash else means
return actions, log_probs, means
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
if self.encoder is not None:
with torch.inference_mode():
return self.encoder(observations)
return observations
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations.
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
"""
def __init__(self, config: SACConfig):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
for image_key in image_keys:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
return torch.stack(feat, dim=0).mean(0)
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
can be more than 1 dimensions, generally different from *.
Returns:
A return value from the callable reshaped to (**, *).
"""
if image_tensor.ndim == 4:
return fn(image_tensor)
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

View File

@ -0,0 +1,991 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
import random
from typing import Optional, Sequence, TypedDict
import hydra
import numpy as np
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, ListConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
from lerobot.scripts.eval import eval_policy
def make_optimizers_and_scheduler(cfg, policy):
optimizer_actor = torch.optim.Adam(
params=policy.actor.parameters(),
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
return optimizers, lr_scheduler
# def update_policy(policy, batch, optimizers, grad_clip_norm):
# NOTE: This is temporary, online buffer or query lerobot dataset is not performant enough yet
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
complementary_info: dict[str, torch.Tensor] = None
class BatchTransition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: torch.Tensor
next_state: dict[str, torch.Tensor]
done: torch.Tensor
class ReplayBuffer:
def __init__(self, capacity: int, device: str = "cuda:0", state_keys: Optional[Sequence[str]] = None):
"""
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
"""
self.capacity = capacity
self.device = device
self.memory: list[Transition] = []
self.position = 0
# If no state_keys provided, default to an empty list
# (you can handle this differently if needed)
self.state_keys = state_keys if state_keys is not None else []
def add(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
reward: float,
next_state: dict[str, torch.Tensor],
done: bool,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Saves a transition."""
if len(self.memory) < self.capacity:
self.memory.append(None)
# Create and store the Transition
self.memory[self.position] = Transition(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
complementary_info=complementary_info,
)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
list_of_transitions = random.sample(self.memory, batch_size)
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.bool).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
state=batch_state,
action=batch_actions,
reward=batch_rewards,
next_state=batch_next_state,
done=batch_dones,
)
# def sample(self, batch_size: int):
# # 1) Randomly sample transitions
# transitions = random.sample(self.memory, batch_size)
# # 2) For each key in state_keys, gather states [b, state_dim], next_states [b, state_dim]
# batch_state = {}
# batch_next_state = {}
# for key in self.state_keys:
# batch_state[key] = torch.cat([t["state"][key] for t in transitions], dim=0).to(
# self.device
# ) # shape [b, state_dim, ...] depending on your data
# batch_next_state[key] = torch.cat([t["next_state"][key] for t in transitions], dim=0).to(
# self.device
# ) # shape [b, state_dim, ...]
# # 3) Build the other tensors
# batch_action = torch.cat([t["action"] for t in transitions], dim=0).to(
# self.device
# ) # shape [b, ...] or [b, action_dim, ...]
# batch_reward = torch.tensor(
# [t["reward"] for t in transitions], dtype=torch.float32, device=self.device
# ).unsqueeze(dim=-1) # shape [b, 1]
# batch_done = torch.tensor(
# [t["done"] for t in transitions], dtype=torch.bool, device=self.device
# ) # shape [b]
# # 4) Create the observation and next_observation dicts
# #
# # Each key is stacked along dim=1 so final shape is [b, 2, state_dim, ...]
# # - observation[key][..., 0, :] is the current state
# # - observation[key][..., 1, :] is the next state
# # - next_observation[key] duplicates the next state to shape [b, 2, ...]
# observation = {}
# for key in self.state_keys:
# observation[key] = torch.stack([batch_state[key], batch_next_state[key]], dim=1)
# # 5) Return your structure
# ret = observation | {"action": batch_action, "next.reward": batch_reward, "next.done": batch_done}
# return ret
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
# Create an env dedicated to online episodes collection from policy rollout.
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
# NOTE: Off policy algorithm are efficient enought to use a single environment
logging.info("make_env online")
online_env = make_env(cfg, n_envs=1)
if cfg.training.eval_freq > 0:
logging.info("make_env eval")
eval_env = make_env(cfg, n_envs=1)
# TODO: Add a way to resume training
# log metrics to terminal and wandb
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
set_global_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
step = 0 # number of policy updates (forward + backward + optim)
# TODO: Handle resume
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
# TODO: Handle offline steps
# logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
logging.info(f"{cfg.training.online_steps=}")
# logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
# logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
obs, info = online_env.reset()
obs = preprocess_observation(obs)
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
)
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
for interaction_step in range(cfg.training.online_steps):
# NOTE: At some point we should use a wrapper to handle the observation
if interaction_step >= cfg.training.online_step_before_learning:
with torch.inference_mode():
action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
next_obs = preprocess_observation(next_obs)
next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
sum_reward_episode += float(reward[0])
# Because we are using a single environment
# we can safely assume that the episode is done
if done[0] or truncated[0]:
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
sum_reward_episode = 0
replay_buffer.add(
state=obs,
action=action,
reward=float(reward[0]),
next_state=next_obs,
done=done[0],
)
obs = next_obs
if interaction_step >= cfg.training.online_step_before_learning:
batch = replay_buffer.sample(cfg.training.batch_size)
# 'observation.state', 'action', 'next.reward', 'next.done'
# TODO: (azouitine) interface to refine
# TODO: At some point we should find a way to normalize the inputs
# batch = policy.normalize_inputs(batch)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
if interaction_step % cfg.training.policy_update_freq == 0:
# TD3 Trick
for _ in range(cfg.training.policy_update_freq):
loss_actor = policy.compute_loss_actor(observations=observations)
optimizers["actor"].zero_grad()
loss_actor.backward()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
if interaction_step % cfg.training.log_freq == 0:
logger.log_dict(training_infos, interaction_step, mode="train")
policy.update_target_networks()
def clip_grad_norm(loss, clip_grad_norm_value, parameters):
grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=parameters,
max_norm=clip_grad_norm_value,
error_if_nonfinite=False,
)
return grad_norm
def update_policy(
policy,
batch,
optimizer,
grad_clip_norm,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
lock=None,
):
"""Returns a dictionary of items for logging."""
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
grad_scaler.scale(loss).backward()
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad()
if lr_scheduler is not None:
lr_scheduler.step()
if isinstance(policy, PolicyWithUpdate):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.perf_counter() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}
info.update({k: v for k, v in output_dict.items() if k not in info})
return info
def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
update_s = info["update_s"]
dataloading_s = info["dataloading_s"]
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_big_number(num_episodes)}",
# number of time all unique samples are seen
f"epch:{num_epochs:.2f}",
f"loss:{loss:.3f}",
f"grdn:{grad_norm:.3f}",
f"lr:{lr:0.1e}",
# in seconds
f"updt_s:{update_s:.3f}",
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
]
logging.info(" ".join(log_items))
info["step"] = step
info["num_samples"] = num_samples
info["num_episodes"] = num_episodes
info["num_epochs"] = num_epochs
info["is_online"] = is_online
logger.log_dict(info, step, mode="train")
def log_eval_info(logger, info, step, cfg, dataset, is_online):
eval_s = info["eval_s"]
avg_sum_reward = info["avg_sum_reward"]
pc_success = info["pc_success"]
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.training.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_big_number(num_episodes)}",
# number of time all unique samples are seen
f"epch:{num_epochs:.2f}",
f"∑rwrd:{avg_sum_reward:.3f}",
f"success:{pc_success:.1f}%",
f"eval_s:{eval_s:.3f}",
]
logging.info(" ".join(log_items))
info["step"] = step
info["num_samples"] = num_samples
info["num_episodes"] = num_episodes
info["num_epochs"] = num_epochs
info["is_online"] = is_online
logger.log_dict(info, step, mode="eval")
# def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
# if out_dir is None:
# raise NotImplementedError()
# if job_name is None:
# raise NotImplementedError()
# init_logging()
# logging.info(pformat(OmegaConf.to_container(cfg)))
# if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
# raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
# # Create an env dedicated to online episodes collection from policy rollout.
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
# if cfg.training.eval_freq > 0:
# logging.info("make_env")
# eval_env = make_env(cfg)
# # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
# # to check for any differences between the provided config and the checkpoint's config.
# if cfg.resume:
# if not Logger.get_last_checkpoint_dir(out_dir).exists():
# raise RuntimeError(
# "You have set resume=True, but there is no model checkpoint in "
# f"{Logger.get_last_checkpoint_dir(out_dir)}"
# )
# checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
# logging.info(
# colored(
# "You have set resume=True, indicating that you wish to resume a run",
# color="yellow",
# attrs=["bold"],
# )
# )
# # Get the configuration file from the last checkpoint.
# checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# # Check for differences between the checkpoint configuration and provided configuration.
# # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
# resolve_delta_timestamps(cfg)
# diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# # Ignore the `resume` and parameters.
# if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
# del diff["values_changed"]["root['resume']"]
# # Log a warning about differences between the checkpoint configuration and the provided
# # configuration.
# if len(diff) > 0:
# logging.warning(
# "At least one difference was detected between the checkpoint configuration and "
# f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
# "takes precedence.",
# )
# # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
# cfg = checkpoint_cfg
# cfg.resume = True
# elif Logger.get_last_checkpoint_dir(out_dir).exists():
# raise RuntimeError(
# f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
# "you meant to resume training, please use `resume=true` in your command or yaml configuration."
# )
# if cfg.eval.batch_size > cfg.eval.n_episodes:
# raise ValueError(
# "The eval batch size is greater than the number of eval episodes "
# f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
# f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
# "This might significantly slow down evaluation. To fix this, you should update your command "
# f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
# f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
# )
# # log metrics to terminal and wandb
# logger = Logger(cfg, out_dir, wandb_job_name=job_name)
# set_global_seed(cfg.seed)
# # Check device is available
# device = get_safe_torch_device(cfg.device, log=True)
# torch.backends.cudnn.benchmark = True
# torch.backends.cuda.matmul.allow_tf32 = True
# logging.info("make_dataset")
# # offline_dataset = make_dataset(cfg)
# # TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment
# # i.e., pusht
# # if "task_index" in offline_dataset.hf_dataset[0]:
# # offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"])
# # if isinstance(offline_dataset, MultiLeRobotDataset):
# # logging.info(
# # "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
# # f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
# # )
# # Create environment used for evaluating checkpoints during training on simulation data.
# # On real-world data, no need to create an environment as evaluations are done outside train.py,
# # using the eval.py instead, with gym_dora environment and dora-rs.
# eval_env = None
# if cfg.training.eval_freq > 0:
# logging.info("make_env")
# eval_env = make_env(cfg)
# logging.info("make_policy")
# policy = make_policy(
# hydra_cfg=cfg,
# # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# # Hack: But if we do online traning, we do not need dataset_stats
# dataset_stats=None,
# pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
# )
# assert isinstance(policy, nn.Module)
# # Create optimizer and scheduler
# # Temporary hack to move optimizer out of policy
# optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# grad_scaler = GradScaler(enabled=cfg.use_amp)
# step = 0 # number of policy updates (forward + backward + optim)
# if cfg.resume:
# step = logger.load_last_training_state(optimizer, lr_scheduler)
# num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
# num_total_params = sum(p.numel() for p in policy.parameters())
# log_output_dir(out_dir)
# logging.info(f"{cfg.env.task=}")
# logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
# logging.info(f"{cfg.training.online_steps=}")
# # logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
# # logging.info(f"{offline_dataset.num_episodes=}")
# logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
# logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# # Note: this helper will be used in offline and online training loops.
# def evaluate_and_checkpoint_if_needed(step, is_online):
# _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
# step_identifier = f"{step:0{_num_digits}d}"
# if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
# logging.info(f"Eval policy at step {step}")
# with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
# assert eval_env is not None
# eval_info = eval_policy(
# eval_env,
# policy,
# cfg.eval.n_episodes,
# videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
# max_episodes_rendered=4,
# start_seed=cfg.seed,
# )
# # log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
# log_eval_info(logger, eval_info["aggregated"], step, cfg, online_dataset, is_online=is_online)
# if cfg.wandb.enable:
# logger.log_video(eval_info["video_paths"][0], step, mode="eval")
# logging.info("Resume training")
# if cfg.training.save_checkpoint and (
# step % cfg.training.save_freq == 0
# or step == cfg.training.offline_steps + cfg.training.online_steps
# ):
# logging.info(f"Checkpoint policy after step {step}")
# # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# # needed (choose 6 as a minimum for consistency without being overkill).
# logger.save_checkpoint(
# step,
# policy,
# optimizer,
# lr_scheduler,
# identifier=step_identifier,
# )
# logging.info("Resume training")
# # create dataloader for offline training
# # if cfg.training.get("drop_n_last_frames"):
# # shuffle = False
# # sampler = EpisodeAwareSampler(
# # offline_dataset.episode_data_index,
# # drop_n_last_frames=cfg.training.drop_n_last_frames,
# # shuffle=True,
# # )
# # else:
# # shuffle = True
# # sampler = None
# # dataloader = torch.utils.data.DataLoader(
# # offline_dataset,
# # num_workers=cfg.training.num_workers,
# # batch_size=cfg.training.batch_size,
# # shuffle=shuffle,
# # sampler=sampler,
# # pin_memory=device.type != "cpu",
# # drop_last=False,
# # )
# # dl_iter = cycle(dataloader)
# policy.train()
# # offline_step = 0
# # for _ in range(step, cfg.training.offline_steps):
# # if offline_step == 0:
# # logging.info("Start offline training on a fixed dataset")
# # start_time = time.perf_counter()
# # batch = next(dl_iter)
# # dataloading_s = time.perf_counter() - start_time
# # for key in batch:
# # batch[key] = batch[key].to(device, non_blocking=True)
# # train_info = update_policy(
# # policy,
# # batch,
# # optimizer,
# # cfg.training.grad_clip_norm,
# # grad_scaler=grad_scaler,
# # lr_scheduler=lr_scheduler,
# # use_amp=cfg.use_amp,
# # )
# # train_info["dataloading_s"] = dataloading_s
# # if step % cfg.training.log_freq == 0:
# # log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
# # # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# # # so we pass in step + 1.
# # evaluate_and_checkpoint_if_needed(step + 1, is_online=False)
# # step += 1
# # offline_step += 1 # noqa: SIM113
# # if cfg.training.online_steps == 0:
# # if eval_env:
# # eval_env.close()
# # logging.info("End of training")
# # return
# # Online training.
# # Create an env dedicated to online episodes collection from policy rollout.
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
# resolve_delta_timestamps(cfg)
# online_buffer_path = logger.log_dir / "online_buffer"
# if cfg.resume and not online_buffer_path.exists():
# # If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
# # buffer.
# logging.warning(
# "When online training is resumed, we load the latest online buffer from the prior run, "
# "and this might not coincide with the state of the buffer as it was at the moment the checkpoint "
# "was made. This is because the online buffer is updated on disk during training, independently "
# "of our explicit checkpointing mechanisms."
# )
# online_dataset = OnlineBuffer(
# online_buffer_path,
# data_spec={
# **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
# **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
# "next.reward": {"shape": (), "dtype": np.dtype("float32")},
# "next.done": {"shape": (), "dtype": np.dtype("?")},
# "next.success": {"shape": (), "dtype": np.dtype("?")},
# },
# buffer_capacity=cfg.training.online_buffer_capacity,
# fps=online_env.unwrapped.metadata["render_fps"],
# delta_timestamps=cfg.training.delta_timestamps,
# )
# # If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
# # makes it possible to do online rollouts in parallel with training updates).
# online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
# # Create dataloader for online training.
# # concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
# # sampler_weights = compute_sampler_weights(
# # offline_dataset,
# # offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
# # online_dataset=online_dataset,
# # # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# # # this final observation in the offline datasets, but we might add them in future.
# # online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
# # online_sampling_ratio=cfg.training.online_sampling_ratio,
# # )
# # sampler = torch.utils.data.WeightedRandomSampler(
# # sampler_weights,
# # num_samples=len(concat_dataset),
# # replacement=True,
# # )
# # dataloader = torch.utils.data.DataLoader(
# # concat_dataset,
# # batch_size=cfg.training.batch_size,
# # num_workers=cfg.training.num_workers,
# # sampler=sampler,
# # pin_memory=device.type != "cpu",
# # drop_last=True,
# # )
# dataloader = torch.utils.data.DataLoader(
# online_dataset,
# batch_size=cfg.training.batch_size,
# # num_workers=cfg.training.num_workers,
# num_workers=0,
# # sampler=sampler,
# pin_memory=device.type != "cpu",
# drop_last=True,
# )
# dl_iter = cycle(dataloader)
# # Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
# # these are still used but effectively do nothing.
# # Hack: Comment the lock
# # lock = Lock()
# # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
# # parallelization of rollouts is handled within the job.
# # Hack: ThreadPoolExecutor
# # executor = ThreadPoolExecutor(max_workers=1)
# online_step = 0
# online_rollout_s = 0 # time take to do online rollout
# update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data
# # Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
# # online rollout option.
# await_update_online_buffer_s = 0
# rollout_start_seed = cfg.training.online_env_seed
# while True:
# if online_step == cfg.training.online_steps:
# break
# if online_step == 0:
# logging.info("Start online training by interacting with environment")
# def sample_trajectory_and_update_buffer():
# nonlocal rollout_start_seed
# # with lock:
# online_rollout_policy.load_state_dict(policy.state_dict())
# online_rollout_policy.eval()
# start_rollout_time = time.perf_counter()
# with torch.no_grad():
# eval_info = eval_policy(
# online_env,
# online_rollout_policy,
# n_episodes=cfg.training.online_rollout_n_episodes,
# max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
# videos_dir=logger.log_dir / "online_rollout_videos",
# return_episode_data=True,
# start_seed=(
# rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
# ),
# )
# online_rollout_s = time.perf_counter() - start_rollout_time
# # with lock:
# start_update_buffer_time = time.perf_counter()
# online_dataset.add_data(eval_info["episodes"])
# # Update the concatenated dataset length used during sampling.
# # concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# # HACK: We do only online training, so we don't need update dataset length because
# # we do not concatenate offline and online datasets.
# # online_dataset.cumulative_sizes = online_dataset.cumsum(online_dataset.datasets)
# # Update the sampling weights.
# # sampler.weights = compute_sampler_weights(
# # offline_dataset,
# # offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
# # online_dataset=online_dataset,
# # # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# # # this final observation in the offline datasets, but we might add them in future.
# # online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
# # online_sampling_ratio=cfg.training.online_sampling_ratio,
# # )
# # sampler.num_frames = len(concat_dataset)
# update_online_buffer_s = time.perf_counter() - start_update_buffer_time
# return online_rollout_s, update_online_buffer_s
# # Hack:Comment it
# # future = executor.submit(sample_trajectory_and_update_buffer)
# # sample_trajectory_and_update_buffer()
# # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# # here until the rollout and buffer update is done, before proceeding to the policy update steps.
# if (
# not cfg.training.do_online_rollout_async
# or len(online_dataset) <= cfg.training.online_buffer_seed_size
# ):
# # online_rollout_s, update_online_buffer_s = future.result()
# online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
# if len(online_dataset) <= cfg.training.online_buffer_seed_size:
# logging.info(
# f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}"
# )
# continue
# policy.train()
# for _ in range(cfg.training.online_steps_between_rollouts):
# # Hack: Comment the lock and reindent
# # with lock:
# start_time = time.perf_counter()
# batch = next(dl_iter)
# dataloading_s = time.perf_counter() - start_time
# for key in batch:
# batch[key] = batch[key].to(cfg.device, non_blocking=True)
# train_info = update_policy(
# policy,
# batch,
# optimizer,
# cfg.training.grad_clip_norm,
# grad_scaler=grad_scaler,
# lr_scheduler=lr_scheduler,
# use_amp=cfg.use_amp,
# # lock=lock,
# # Hack: Comment the lock
# lock=None,
# )
# train_info["dataloading_s"] = dataloading_s
# train_info["online_rollout_s"] = online_rollout_s
# train_info["update_online_buffer_s"] = update_online_buffer_s
# train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
# # Hack: Comment the lock and reindent
# # with lock:
# train_info["online_buffer_size"] = len(online_dataset)
# if step % cfg.training.log_freq == 0:
# log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
# # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# # so we pass in step + 1.
# evaluate_and_checkpoint_if_needed(step + 1, is_online=True)
# step += 1
# online_step += 1
# # If we're doing async rollouts, we should now wait until we've completed them before proceeding
# # to do the next batch of rollouts.
# # Hack: comment it
# # if future.running():
# start = time.perf_counter()
# # online_rollout_s, update_online_buffer_s = future.result()
# online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
# await_update_online_buffer_s = time.perf_counter() - start
# if online_step >= cfg.training.online_steps:
# break
# if eval_env:
# eval_env.close()
# logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
if __name__ == "__main__":
train_cli()