trying to get sac running

This commit is contained in:
KeWang1017 2024-12-26 23:38:46 +00:00 committed by AdilZouitine
parent 80b86e9bc3
commit a113daa81e
3 changed files with 149 additions and 40 deletions

View File

@ -20,6 +20,24 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class SACConfig: class SACConfig:
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
discount = 0.99 discount = 0.99
temperature_init = 1.0 temperature_init = 1.0
num_critics = 2 num_critics = 2
@ -29,6 +47,9 @@ class SACConfig:
temperature_lr = 3e-4 temperature_lr = 3e-4
critic_target_update_weight = 0.005 critic_target_update_weight = 0.005
utd_ratio = 2 utd_ratio = 2
state_encoder_hidden_dim = 256
latent_dim = 50
target_entropy = None
critic_network_kwargs = { critic_network_kwargs = {
"hidden_dims": [256, 256], "hidden_dims": [256, 256],
"activate_final": True, "activate_final": True,

View File

@ -40,6 +40,8 @@ class SACPolicy(
repo_url="https://github.com/huggingface/lerobot", repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"], tags=["robotics", "RL", "SAC"],
): ):
name = "sac"
def __init__( def __init__(
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
): ):
@ -71,7 +73,7 @@ class SACPolicy(
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics) self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = deepcopy(self.critic_ensemble) self.critic_target = deepcopy(self.critic_ensemble)
self.actor_network = Policy( self.actor = Policy(
encoder=encoder, encoder=encoder,
network=MLP(**config.actor_network_kwargs), network=MLP(**config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0], action_dim=config.output_shapes["action"][0],
@ -91,14 +93,14 @@ class SACPolicy(
"observation.state": deque(maxlen=1), "observation.state": deque(maxlen=1),
"action": deque(maxlen=1), "action": deque(maxlen=1),
} }
if self._use_image: if "observation.image" in self.config.input_shapes:
self._queues["observation.image"] = deque(maxlen=1) self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state: if "observation.environment_state" in self.config.input_shapes:
self._queues["observation.environment_state"] = deque(maxlen=1) self._queues["observation.environment_state"] = deque(maxlen=1)
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
actions, _ = self.actor_network(batch["observations"]) ### actions, _ = self.actor(batch['observations'])
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss. """Run the batch through the model and compute the loss.
@ -119,13 +121,12 @@ class SACPolicy(
# perform image augmentation # perform image augmentation
# reward bias # reward bias from HIL-SERL code base
# from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss # calculate critics loss
# 1- compute actions from policy # 1- compute actions from policy
action_preds, log_probs = self.actor_network(observations) action_preds, log_probs = self.actor(observations)
# 2- compute q targets # 2- compute q targets
q_targets = self.target_qs(next_observations, action_preds) q_targets = self.target_qs(next_observations, action_preds)
# subsample critics to prevent overfitting if use high UTD (update to date) # subsample critics to prevent overfitting if use high UTD (update to date)
@ -168,7 +169,8 @@ class SACPolicy(
temperature = self.temperature() temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,) # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor_network(observations) actions, log_probs = self.actor(observations) \
# 3- get q-value predictions # 3- get q-value predictions
with torch.no_grad(): with torch.no_grad():
q_preds = self.critic_ensemble(observations, actions, return_type="mean") q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@ -209,21 +211,19 @@ class SACPolicy(
class MLP(nn.Module): class MLP(nn.Module):
def __init__( def __init__(
self, self,
config: SACConfig, hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False, activate_final: bool = False,
dropout_rate: Optional[float] = None, dropout_rate: Optional[float] = None,
): ):
super().__init__() super().__init__()
self.activate_final = config.activate_final self.activate_final = activate_final
layers = [] layers = []
for i, size in enumerate(config.network_hidden_dims): for i, size in enumerate(hidden_dims):
layers.append( layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
)
if i + 1 < len(config.network_hidden_dims) or activate_final: if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0: if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(size)) layers.append(nn.LayerNorm(size))
@ -255,19 +255,19 @@ class Critic(nn.Module):
self.init_final = init_final self.init_final = init_final
self.activate_final = activate_final self.activate_final = activate_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Output layer # Output layer
if init_final is not None: if init_final is not None:
if self.activate_final: self.output_layer = nn.Linear(out_features, 1)
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
nn.init.uniform_(self.output_layer.weight, -init_final, init_final) nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final) nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else: else:
if self.activate_final: self.output_layer = nn.Linear(out_features, 1)
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
orthogonal_init()(self.output_layer.weight) orthogonal_init()(self.output_layer.weight)
self.to(self.device) self.to(self.device)
@ -329,11 +329,14 @@ class Policy(nn.Module):
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.activate_final = activate_final self.activate_final = activate_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer # Mean layer
if self.activate_final: self.mean_layer = nn.Linear(out_features, action_dim)
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
if init_final is not None: if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final) nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
@ -345,10 +348,7 @@ class Policy(nn.Module):
if std_parameterization == "uniform": if std_parameterization == "uniform":
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device)) self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
else: else:
if self.activate_final: self.std_layer = nn.Linear(out_features, action_dim)
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
if init_final is not None: if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final) nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final) nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
@ -571,7 +571,6 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
"""Get the mode of the transformed distribution""" """Get the mode of the transformed distribution"""
# The mode of a normal distribution is its mean # The mode of a normal distribution is its mean
mode = self.loc mode = self.loc
# Apply transforms # Apply transforms
for transform in self.transforms: for transform in self.transforms:
mode = transform(mode) mode = transform(mode)
@ -634,10 +633,10 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
return entropy return entropy
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList: def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
"""Creates an ensemble of critic networks""" """Creates an ensemble of critic networks"""
critics = nn.ModuleList([critic_class() for _ in range(num_critics)]) assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return critics.to(device) return nn.ModuleList(critics).to(device)
def orthogonal_init(): def orthogonal_init():

View File

@ -0,0 +1,89 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# env=pusht \
# +dataset=lerobot/pusht_keypoints
seed: 1
dataset_repo_id: lerobot/pusht_keypoints
training:
offline_steps: 0
# Offline training dataloader
num_workers: 4
batch_size: 128
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 10000
log_freq: 500
save_freq: 50000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 40000
online_buffer_seed_size: 0
do_online_rollout_async: false
delta_timestamps:
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
action: "[i / ${fps} for i in range(${policy.horizon})]"
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 5
n_action_steps: 5
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.environment_state: [16]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.environment_state: min_max
observation.state: min_max
output_normalization_modes:
action: min_max
# Architecture / modeling.
# Neural networks.
# image_encoder_hidden_dim: 32
discount: 0.99
temperature_init: 1.0
num_critics: 2
num_subsample_critics: None
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
critic_target_update_weight: 0.005
utd_ratio: 2
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995