Trying to get SAC running

KeWang1017 2024-12-26 23:38:46 +00:00 committed by Michel Aractingi
parent dc54d357ca
commit 18a4598986
3 changed files with 149 additions and 40 deletions


@@ -20,6 +20,24 @@ from dataclasses import dataclass, field
@dataclass
class SACConfig:
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
discount = 0.99
temperature_init = 1.0
num_critics = 2
@@ -29,6 +47,9 @@ class SACConfig:
temperature_lr = 3e-4
critic_target_update_weight = 0.005
utd_ratio = 2
state_encoder_hidden_dim = 256
latent_dim = 50
target_entropy = None
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,

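For orientation, a minimal sketch of how the new config might be instantiated (the import path is an assumption; the printed defaults are the values shown in the hunk above):

from lerobot.common.policies.sac.configuration_sac import SACConfig  # assumed module path

config = SACConfig()
print(config.input_shapes["observation.image"])  # [3, 84, 84]
print(config.output_shapes["action"])            # [4]
print(config.discount, config.num_critics)       # 0.99 2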

@@ -40,6 +40,8 @@ class SACPolicy(
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
):
name = "sac"
def __init__(
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
@@ -71,7 +73,7 @@ class SACPolicy(
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = deepcopy(self.critic_ensemble)
self.actor_network = Policy(
self.actor = Policy(
encoder=encoder,
network=MLP(**config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
@@ -91,14 +93,14 @@ class SACPolicy(
"observation.state": deque(maxlen=1),
"action": deque(maxlen=1),
}
if self._use_image:
if "observation.image" in self.config.input_shapes:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
if "observation.environment_state" in self.config.input_shapes:
self._queues["observation.environment_state"] = deque(maxlen=1)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
actions, _ = self.actor_network(batch["observations"]) ###
actions, _ = self.actor(batch['observations'])
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
@@ -119,19 +121,18 @@ class SACPolicy(
# perform image augmentation
# reward bias
# from HIL-SERL code base
# reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss
# 1- compute actions from policy
action_preds, log_probs = self.actor_network(observations)
action_preds, log_probs = self.actor(observations)
# 2- compute q targets
q_targets = self.target_qs(next_observations, action_preds)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
indices = indices[:self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
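The subsampling above is the REDQ-style trick of computing targets from a random subset of critics; a self-contained sketch of the idea (tensor shapes and the final min-reduction are assumptions, since the rest of the target computation is not in this hunk):

import torch

num_critics, num_subsample_critics, batch_size = 10, 2, 128
q_targets = torch.randn(num_critics, batch_size)           # one Q estimate per critic

indices = torch.randperm(num_critics)[:num_subsample_critics]
q_targets = q_targets[indices]                              # keep a random subset of critics
td_component = q_targets.min(dim=0).values                  # pessimistic target over the subset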
@@ -168,7 +169,8 @@ class SACPolicy(
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor_network(observations)
actions, log_probs = self.actor(observations)
# 3- get q-value predictions
with torch.no_grad():
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
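For context, these three quantities feed the standard SAC actor objective; the sketch below is the textbook form, not necessarily the exact reduction used later in this file:

import torch

temperature = 0.2                      # stand-in for self.temperature()
log_probs = torch.randn(128)           # (batch_size,) from the actor
q_preds = torch.randn(128)             # mean Q over the critic ensemble

# Maximize E[Q(s, a) - alpha * log pi(a|s)]  <=>  minimize this loss.
actor_loss = (temperature * log_probs - q_preds).mean()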
@@ -209,21 +211,19 @@ class MLP(nn.Module):
class MLP(nn.Module):
def __init__(
self,
config: SACConfig,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
):
super().__init__()
self.activate_final = config.activate_final
self.activate_final = activate_final
layers = []
for i, size in enumerate(config.network_hidden_dims):
layers.append(
nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
)
for i, size in enumerate(hidden_dims):
layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
if i + 1 < len(config.network_hidden_dims) or activate_final:
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(size))
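The refactor replaces the config-driven constructor with explicit hidden_dims / activate_final / dropout_rate arguments; below is a stand-alone sketch of the same layer-building loop (the trailing activation append is an assumption, since the hunk is cut off before it):

import torch
from torch import nn

hidden_dims, activate_final, dropout_rate = [256, 256], True, None
layers = []
for i, size in enumerate(hidden_dims):
    layers.append(nn.Linear(hidden_dims[i - 1] if i > 0 else hidden_dims[0], size))
    if i + 1 < len(hidden_dims) or activate_final:
        if dropout_rate is not None and dropout_rate > 0:
            layers.append(nn.Dropout(p=dropout_rate))
        layers.append(nn.LayerNorm(size))
        layers.append(nn.SiLU())            # assumed; not visible in the truncated hunk
net = nn.Sequential(*layers)
print(net(torch.randn(8, 256)).shape)       # torch.Size([8, 256]); input dim == hidden_dims[0]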
@@ -255,19 +255,19 @@ class Critic(nn.Module):
self.init_final = init_final
self.activate_final = activate_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Output layer
if init_final is not None:
if self.activate_final:
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
self.output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
if self.activate_final:
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
self.output_layer = nn.Linear(out_features, 1)
orthogonal_init()(self.output_layer.weight)
self.to(self.device)
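Both the Critic and the Policy hunks now infer the head's input width by scanning the wrapped network for its last nn.Linear; a self-contained sketch of that pattern (the example Sequential is arbitrary):

from torch import nn

net = nn.Sequential(
    nn.Linear(64, 256), nn.LayerNorm(256), nn.SiLU(),
    nn.Linear(256, 256), nn.LayerNorm(256), nn.SiLU(),
)

# Walk backwards until the last nn.Linear is found, as the diff does with network.net.
for layer in reversed(net):
    if isinstance(layer, nn.Linear):
        out_features = layer.out_features
        break
print(out_features)  # 256
output_layer = nn.Linear(out_features, 1)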
@@ -329,11 +329,14 @@ class Policy(nn.Module):
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.activate_final = activate_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
if self.activate_final:
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
@@ -345,10 +348,7 @@ class Policy(nn.Module):
if std_parameterization == "uniform":
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
else:
if self.activate_final:
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
@@ -571,7 +571,6 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
"""Get the mode of the transformed distribution"""
# The mode of a normal distribution is its mean
mode = self.loc
# Apply transforms
for transform in self.transforms:
mode = transform(mode)
@@ -634,10 +633,10 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
return entropy
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
return critics.to(device)
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)
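With the new signature the caller builds the critic modules first; a hypothetical call site (a plain nn.Linear stands in for the real Critic, which needs an encoder and an MLP):

from copy import deepcopy
from torch import nn

num_critics = 2
critic_nets = [nn.Linear(8, 1) for _ in range(num_critics)]   # placeholders for Critic(...)
critic_ensemble = create_critic_ensemble(critic_nets, num_critics, device="cpu")
critic_target = deepcopy(critic_ensemble)                     # mirrors SACPolicy.__init__ above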
def orthogonal_init():


@@ -0,0 +1,89 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# env=pusht \
# +dataset=lerobot/pusht_keypoints
seed: 1
dataset_repo_id: lerobot/pusht_keypoints
training:
offline_steps: 0
# Offline training dataloader
num_workers: 4
batch_size: 128
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 10000
log_freq: 500
save_freq: 50000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 40000
online_buffer_seed_size: 0
do_online_rollout_async: false
delta_timestamps:
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
action: "[i / ${fps} for i in range(${policy.horizon})]"
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 5
n_action_steps: 5
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.environment_state: [16]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.environment_state: min_max
observation.state: min_max
output_normalization_modes:
action: min_max
# Architecture / modeling.
# Neural networks.
# image_encoder_hidden_dim: 32
discount: 0.99
temperature_init: 1.0
num_critics: 2
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
critic_target_update_weight: 0.005
utd_ratio: 2
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995
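As a sanity check on the delta_timestamps entries, this is how the interpolated lists expand once Hydra resolves ${fps} and ${policy.horizon} (fps=10 is an assumed example value; horizon=5 comes from this config):

fps, horizon = 10, 5   # fps is an assumed example; horizon matches policy.horizon above

observation_state = [i / fps for i in range(horizon + 1)]   # [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
action = [i / fps for i in range(horizon)]                  # [0.0, 0.1, 0.2, 0.3, 0.4]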