trying to get sac running
This commit is contained in:
parent
80b86e9bc3
commit
a113daa81e
|
@ -20,6 +20,24 @@ from dataclasses import dataclass, field
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SACConfig:
|
class SACConfig:
|
||||||
|
input_shapes: dict[str, list[int]] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"observation.image": [3, 84, 84],
|
||||||
|
"observation.state": [4],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
output_shapes: dict[str, list[int]] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": [4],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes: dict[str, str] | None = None
|
||||||
|
output_normalization_modes: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {"action": "min_max"},
|
||||||
|
)
|
||||||
|
|
||||||
discount = 0.99
|
discount = 0.99
|
||||||
temperature_init = 1.0
|
temperature_init = 1.0
|
||||||
num_critics = 2
|
num_critics = 2
|
||||||
|
@ -29,6 +47,9 @@ class SACConfig:
|
||||||
temperature_lr = 3e-4
|
temperature_lr = 3e-4
|
||||||
critic_target_update_weight = 0.005
|
critic_target_update_weight = 0.005
|
||||||
utd_ratio = 2
|
utd_ratio = 2
|
||||||
|
state_encoder_hidden_dim = 256
|
||||||
|
latent_dim = 50
|
||||||
|
target_entropy = None
|
||||||
critic_network_kwargs = {
|
critic_network_kwargs = {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
"activate_final": True,
|
"activate_final": True,
|
||||||
|
|
|
@ -40,6 +40,8 @@ class SACPolicy(
|
||||||
repo_url="https://github.com/huggingface/lerobot",
|
repo_url="https://github.com/huggingface/lerobot",
|
||||||
tags=["robotics", "RL", "SAC"],
|
tags=["robotics", "RL", "SAC"],
|
||||||
):
|
):
|
||||||
|
name = "sac"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||||
):
|
):
|
||||||
|
@ -71,7 +73,7 @@ class SACPolicy(
|
||||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||||
self.critic_target = deepcopy(self.critic_ensemble)
|
self.critic_target = deepcopy(self.critic_ensemble)
|
||||||
|
|
||||||
self.actor_network = Policy(
|
self.actor = Policy(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
network=MLP(**config.actor_network_kwargs),
|
network=MLP(**config.actor_network_kwargs),
|
||||||
action_dim=config.output_shapes["action"][0],
|
action_dim=config.output_shapes["action"][0],
|
||||||
|
@ -91,14 +93,14 @@ class SACPolicy(
|
||||||
"observation.state": deque(maxlen=1),
|
"observation.state": deque(maxlen=1),
|
||||||
"action": deque(maxlen=1),
|
"action": deque(maxlen=1),
|
||||||
}
|
}
|
||||||
if self._use_image:
|
if "observation.image" in self.config.input_shapes:
|
||||||
self._queues["observation.image"] = deque(maxlen=1)
|
self._queues["observation.image"] = deque(maxlen=1)
|
||||||
if self._use_env_state:
|
if "observation.environment_state" in self.config.input_shapes:
|
||||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
actions, _ = self.actor_network(batch["observations"]) ###
|
actions, _ = self.actor(batch['observations'])
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||||
"""Run the batch through the model and compute the loss.
|
"""Run the batch through the model and compute the loss.
|
||||||
|
@ -119,13 +121,12 @@ class SACPolicy(
|
||||||
|
|
||||||
# perform image augmentation
|
# perform image augmentation
|
||||||
|
|
||||||
# reward bias
|
# reward bias from HIL-SERL code base
|
||||||
# from HIL-SERL code base
|
|
||||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||||
|
|
||||||
# calculate critics loss
|
# calculate critics loss
|
||||||
# 1- compute actions from policy
|
# 1- compute actions from policy
|
||||||
action_preds, log_probs = self.actor_network(observations)
|
action_preds, log_probs = self.actor(observations)
|
||||||
# 2- compute q targets
|
# 2- compute q targets
|
||||||
q_targets = self.target_qs(next_observations, action_preds)
|
q_targets = self.target_qs(next_observations, action_preds)
|
||||||
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
|
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
|
||||||
|
@ -168,7 +169,8 @@ class SACPolicy(
|
||||||
temperature = self.temperature()
|
temperature = self.temperature()
|
||||||
|
|
||||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||||
actions, log_probs = self.actor_network(observations)
|
actions, log_probs = self.actor(observations) \
|
||||||
|
|
||||||
# 3- get q-value predictions
|
# 3- get q-value predictions
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
||||||
|
@ -209,21 +211,19 @@ class SACPolicy(
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: SACConfig,
|
hidden_dims: list[int],
|
||||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||||
activate_final: bool = False,
|
activate_final: bool = False,
|
||||||
dropout_rate: Optional[float] = None,
|
dropout_rate: Optional[float] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.activate_final = config.activate_final
|
self.activate_final = activate_final
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for i, size in enumerate(config.network_hidden_dims):
|
for i, size in enumerate(hidden_dims):
|
||||||
layers.append(
|
layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
|
||||||
nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
|
|
||||||
)
|
|
||||||
|
|
||||||
if i + 1 < len(config.network_hidden_dims) or activate_final:
|
if i + 1 < len(hidden_dims) or activate_final:
|
||||||
if dropout_rate is not None and dropout_rate > 0:
|
if dropout_rate is not None and dropout_rate > 0:
|
||||||
layers.append(nn.Dropout(p=dropout_rate))
|
layers.append(nn.Dropout(p=dropout_rate))
|
||||||
layers.append(nn.LayerNorm(size))
|
layers.append(nn.LayerNorm(size))
|
||||||
|
@ -255,19 +255,19 @@ class Critic(nn.Module):
|
||||||
self.init_final = init_final
|
self.init_final = init_final
|
||||||
self.activate_final = activate_final
|
self.activate_final = activate_final
|
||||||
|
|
||||||
|
# Find the last Linear layer's output dimension
|
||||||
|
for layer in reversed(network.net):
|
||||||
|
if isinstance(layer, nn.Linear):
|
||||||
|
out_features = layer.out_features
|
||||||
|
break
|
||||||
|
|
||||||
# Output layer
|
# Output layer
|
||||||
if init_final is not None:
|
if init_final is not None:
|
||||||
if self.activate_final:
|
self.output_layer = nn.Linear(out_features, 1)
|
||||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
|
||||||
else:
|
|
||||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
|
||||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||||
else:
|
else:
|
||||||
if self.activate_final:
|
self.output_layer = nn.Linear(out_features, 1)
|
||||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
|
||||||
else:
|
|
||||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
|
||||||
orthogonal_init()(self.output_layer.weight)
|
orthogonal_init()(self.output_layer.weight)
|
||||||
|
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
@ -329,11 +329,14 @@ class Policy(nn.Module):
|
||||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||||
self.activate_final = activate_final
|
self.activate_final = activate_final
|
||||||
|
|
||||||
|
# Find the last Linear layer's output dimension
|
||||||
|
for layer in reversed(network.net):
|
||||||
|
if isinstance(layer, nn.Linear):
|
||||||
|
out_features = layer.out_features
|
||||||
|
break
|
||||||
|
|
||||||
# Mean layer
|
# Mean layer
|
||||||
if self.activate_final:
|
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||||
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
|
||||||
else:
|
|
||||||
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
|
||||||
if init_final is not None:
|
if init_final is not None:
|
||||||
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
||||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||||
|
@ -345,10 +348,7 @@ class Policy(nn.Module):
|
||||||
if std_parameterization == "uniform":
|
if std_parameterization == "uniform":
|
||||||
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
|
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
|
||||||
else:
|
else:
|
||||||
if self.activate_final:
|
self.std_layer = nn.Linear(out_features, action_dim)
|
||||||
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
|
||||||
else:
|
|
||||||
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
|
||||||
if init_final is not None:
|
if init_final is not None:
|
||||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||||
|
@ -571,7 +571,6 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
"""Get the mode of the transformed distribution"""
|
"""Get the mode of the transformed distribution"""
|
||||||
# The mode of a normal distribution is its mean
|
# The mode of a normal distribution is its mean
|
||||||
mode = self.loc
|
mode = self.loc
|
||||||
|
|
||||||
# Apply transforms
|
# Apply transforms
|
||||||
for transform in self.transforms:
|
for transform in self.transforms:
|
||||||
mode = transform(mode)
|
mode = transform(mode)
|
||||||
|
@ -634,10 +633,10 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
|
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
|
||||||
"""Creates an ensemble of critic networks"""
|
"""Creates an ensemble of critic networks"""
|
||||||
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
|
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||||
return critics.to(device)
|
return nn.ModuleList(critics).to(device)
|
||||||
|
|
||||||
|
|
||||||
def orthogonal_init():
|
def orthogonal_init():
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Train with:
|
||||||
|
#
|
||||||
|
# python lerobot/scripts/train.py \
|
||||||
|
# env=pusht \
|
||||||
|
# +dataset=lerobot/pusht_keypoints
|
||||||
|
|
||||||
|
seed: 1
|
||||||
|
dataset_repo_id: lerobot/pusht_keypoints
|
||||||
|
|
||||||
|
training:
|
||||||
|
offline_steps: 0
|
||||||
|
|
||||||
|
# Offline training dataloader
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
batch_size: 128
|
||||||
|
grad_clip_norm: 10.0
|
||||||
|
lr: 3e-4
|
||||||
|
|
||||||
|
eval_freq: 10000
|
||||||
|
log_freq: 500
|
||||||
|
save_freq: 50000
|
||||||
|
|
||||||
|
online_steps: 1000000
|
||||||
|
online_rollout_n_episodes: 10
|
||||||
|
online_rollout_batch_size: 10
|
||||||
|
online_steps_between_rollouts: 1000
|
||||||
|
online_sampling_ratio: 1.0
|
||||||
|
online_env_seed: 10000
|
||||||
|
online_buffer_capacity: 40000
|
||||||
|
online_buffer_seed_size: 0
|
||||||
|
do_online_rollout_async: false
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||||
|
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||||
|
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||||
|
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||||
|
|
||||||
|
policy:
|
||||||
|
name: sac
|
||||||
|
|
||||||
|
pretrained_model_path:
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_action_repeats: 1
|
||||||
|
horizon: 5
|
||||||
|
n_action_steps: 5
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
observation.environment_state: [16]
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes:
|
||||||
|
observation.environment_state: min_max
|
||||||
|
observation.state: min_max
|
||||||
|
output_normalization_modes:
|
||||||
|
action: min_max
|
||||||
|
|
||||||
|
# Architecture / modeling.
|
||||||
|
# Neural networks.
|
||||||
|
# image_encoder_hidden_dim: 32
|
||||||
|
discount: 0.99
|
||||||
|
temperature_init: 1.0
|
||||||
|
num_critics: 2
|
||||||
|
# Use YAML `null`: a bare `None` is not a YAML null and would load as the string "None".
num_subsample_critics: null
|
||||||
|
critic_lr: 3e-4
|
||||||
|
actor_lr: 3e-4
|
||||||
|
temperature_lr: 3e-4
|
||||||
|
critic_target_update_weight: 0.005
|
||||||
|
utd_ratio: 2
|
||||||
|
|
||||||
|
|
||||||
|
# # Loss coefficients.
|
||||||
|
# reward_coeff: 0.5
|
||||||
|
# expectile_weight: 0.9
|
||||||
|
# value_coeff: 0.1
|
||||||
|
# consistency_coeff: 20.0
|
||||||
|
# advantage_scaling: 3.0
|
||||||
|
# pi_coeff: 0.5
|
||||||
|
# temporal_decay_coeff: 0.5
|
||||||
|
# # Target model.
|
||||||
|
# target_model_momentum: 0.995
|
Loading…
Reference in New Issue