trying to get sac running
This commit is contained in:
parent
80b86e9bc3
commit
a113daa81e
|
@ -20,6 +20,24 @@ from dataclasses import dataclass, field
|
|||
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [4],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] | None = None
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
||||
discount = 0.99
|
||||
temperature_init = 1.0
|
||||
num_critics = 2
|
||||
|
@ -29,6 +47,9 @@ class SACConfig:
|
|||
temperature_lr = 3e-4
|
||||
critic_target_update_weight = 0.005
|
||||
utd_ratio = 2
|
||||
state_encoder_hidden_dim = 256
|
||||
latent_dim = 50
|
||||
target_entropy = None
|
||||
critic_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
|
|
|
@ -40,6 +40,8 @@ class SACPolicy(
|
|||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
|
@ -71,7 +73,7 @@ class SACPolicy(
|
|||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||
self.critic_target = deepcopy(self.critic_ensemble)
|
||||
|
||||
self.actor_network = Policy(
|
||||
self.actor = Policy(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
|
@ -91,14 +93,14 @@ class SACPolicy(
|
|||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=1),
|
||||
}
|
||||
if self._use_image:
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
actions, _ = self.actor_network(batch["observations"]) ###
|
||||
actions, _ = self.actor(batch['observations'])
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
@ -119,19 +121,18 @@ class SACPolicy(
|
|||
|
||||
# perform image augmentation
|
||||
|
||||
# reward bias
|
||||
# from HIL-SERL code base
|
||||
# reward bias from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
action_preds, log_probs = self.actor_network(observations)
|
||||
action_preds, log_probs = self.actor(observations)
|
||||
# 2- compute q targets
|
||||
q_targets = self.target_qs(next_observations, action_preds)
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
indices = indices[:self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
|
@ -168,7 +169,8 @@ class SACPolicy(
|
|||
temperature = self.temperature()
|
||||
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs = self.actor_network(observations)
|
||||
actions, log_probs = self.actor(observations) \
|
||||
|
||||
# 3- get q-value predictions
|
||||
with torch.no_grad():
|
||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
||||
|
@ -209,21 +211,19 @@ class SACPolicy(
|
|||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = config.activate_final
|
||||
self.activate_final = activate_final
|
||||
layers = []
|
||||
|
||||
for i, size in enumerate(config.network_hidden_dims):
|
||||
layers.append(
|
||||
nn.Linear(config.network_hidden_dims[i - 1] if i > 0 else config.network_hidden_dims[0], size)
|
||||
)
|
||||
|
||||
if i + 1 < len(config.network_hidden_dims) or activate_final:
|
||||
|
||||
for i, size in enumerate(hidden_dims):
|
||||
layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(size))
|
||||
|
@ -254,20 +254,20 @@ class Critic(nn.Module):
|
|||
self.network = network
|
||||
self.init_final = init_final
|
||||
self.activate_final = activate_final
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
if self.activate_final:
|
||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
||||
else:
|
||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
if self.activate_final:
|
||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
||||
else:
|
||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.to(self.device)
|
||||
|
@ -328,12 +328,15 @@ class Policy(nn.Module):
|
|||
self.tanh_squash_distribution = tanh_squash_distribution
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.activate_final = activate_final
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Mean layer
|
||||
if self.activate_final:
|
||||
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||
else:
|
||||
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
||||
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
|
@ -345,10 +348,7 @@ class Policy(nn.Module):
|
|||
if std_parameterization == "uniform":
|
||||
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
|
||||
else:
|
||||
if self.activate_final:
|
||||
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||
else:
|
||||
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
|
@ -571,7 +571,6 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
|||
"""Get the mode of the transformed distribution"""
|
||||
# The mode of a normal distribution is its mean
|
||||
mode = self.loc
|
||||
|
||||
# Apply transforms
|
||||
for transform in self.transforms:
|
||||
mode = transform(mode)
|
||||
|
@ -634,10 +633,10 @@ class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
|||
return entropy
|
||||
|
||||
|
||||
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
|
||||
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
|
||||
return critics.to(device)
|
||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# env=pusht \
|
||||
# +dataset=lerobot/pusht_keypoints
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 0
|
||||
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 128
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 10000
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 40000
|
||||
online_buffer_seed_size: 0
|
||||
do_online_rollout_async: false
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 5
|
||||
n_action_steps: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
# image_encoder_hidden_dim: 32
|
||||
discount: 0.99
|
||||
temperature_init: 1.0
|
||||
num_critics: 2
|
||||
num_subsample_critics: None
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
critic_target_update_weight: 0.005
|
||||
utd_ratio: 2
|
||||
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
# expectile_weight: 0.9
|
||||
# value_coeff: 0.1
|
||||
# consistency_coeff: 20.0
|
||||
# advantage_scaling: 3.0
|
||||
# pi_coeff: 0.5
|
||||
# temporal_decay_coeff: 0.5
|
||||
# # Target model.
|
||||
# target_model_momentum: 0.995
|
Loading…
Reference in New Issue