Port SAC WIP (#581)

Co-authored-by: KeWang1017 <ke.wang@helloleap.ai>
This commit is contained in:
KeWang 2024-12-17 13:26:17 +00:00 committed by GitHub
parent 972bac98b4
commit 5ae5dcb30f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 518 additions and 6 deletions

View File

@ -21,3 +21,19 @@ from dataclasses import dataclass
@dataclass @dataclass
class SACConfig: class SACConfig:
discount = 0.99 discount = 0.99
temperature_init = 1.0
num_critics = 2
critic_lr = 3e-4
actor_lr = 3e-4
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
actor_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
policy_kwargs = {
"tanh_squash_distribution": True,
"std_parameterization": "uniform",
}

View File

@ -15,7 +15,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# TODO: (1) better device management
from collections import deque from collections import deque
from copy import deepcopy
from functools import partial
import einops import einops
@ -27,6 +31,10 @@ from torch import Tensor
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.sac.configuration_sac import SACConfig
import numpy as np
from typing import Callable, Optional, Tuple, Sequence
class SACPolicy( class SACPolicy(
nn.Module, nn.Module,
@ -58,12 +66,27 @@ class SACPolicy(
self.unnormalize_outputs = Unnormalize( self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats config.output_shapes, config.output_normalization_modes, dataset_stats
) )
encoder = SACObservationEncoder(config)
# Define networks
critic_nets = []
for _ in range(config.num_critics):
critic_net = Critic(
encoder=encoder,
network=MLP(**config.critic_network_kwargs)
)
critic_nets.append(critic_net)
self.critic_ensemble = ... self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = ... self.critic_target = deepcopy(self.critic_ensemble)
self.actor_network = ...
self.temperature = ... self.actor_network = Policy(
encoder=encoder,
network=MLP(**config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
**config.policy_kwargs
)
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
def reset(self): def reset(self):
""" """
@ -178,10 +201,483 @@ class SACPolicy(
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()): #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight) # target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
class MLP(nn.Module):
def __init__(
self,
config: SACConfig,
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
):
super().__init__()
self.activate_final = config.activate_final
layers = []
for i, size in enumerate(config.network_hidden_dims):
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
if i + 1 < len(config.network_hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(size))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
# in training mode or not. TODO: find better way to do this
self.train(train)
return self.net(x)
class Critic(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.init_final = init_final
self.activate_final = activate_final
# Output layer
if init_final is not None:
if self.activate_final:
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
if self.activate_final:
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
orthogonal_init()(self.output_layer.weight)
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
actions: torch.Tensor,
train: bool = False
) -> torch.Tensor:
self.train(train)
observations = observations.to(self.device)
actions = actions.to(self.device)
if self.encoder is not None:
obs_enc = self.encoder(observations)
else:
obs_enc = observations
inputs = torch.cat([obs_enc, actions], dim=-1)
x = self.network(inputs)
value = self.output_layer(x)
return value.squeeze(-1)
def q_value_ensemble(
self,
observations: torch.Tensor,
actions: torch.Tensor,
train: bool = False
) -> torch.Tensor:
observations = observations.to(self.device)
actions = actions.to(self.device)
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
batch_size, num_actions = actions.shape[:2]
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
actions_flat = actions.reshape(-1, actions.shape[-1])
q_values = self(obs_flat, actions_flat, train)
return q_values.reshape(batch_size, num_actions)
else:
return self(observations, actions, train)
class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
action_dim: int,
std_parameterization: str = "exp",
std_min: float = 1e-5,
std_max: float = 10.0,
tanh_squash_distribution: bool = False,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.std_parameterization = std_parameterization
self.std_min = std_min
self.std_max = std_max
self.tanh_squash_distribution = tanh_squash_distribution
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.activate_final = activate_final
# Mean layer
if self.activate_final:
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.mean_layer.weight)
# Standard deviation layer or parameter
if fixed_std is None:
if std_parameterization == "uniform":
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
else:
if self.activate_final:
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
temperature: float = 1.0,
train: bool = False,
non_squash_distribution: bool = False
) -> torch.distributions.Distribution:
self.train(train)
# Encode observations if encoder exists
if self.encoder is not None:
with torch.set_grad_enabled(train):
obs_enc = self.encoder(observations, train=train)
else:
obs_enc = observations
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
if self.std_parameterization == "exp":
log_stds = self.std_layer(outputs)
stds = torch.exp(log_stds)
elif self.std_parameterization == "softplus":
stds = torch.nn.functional.softplus(self.std_layer(outputs))
elif self.std_parameterization == "uniform":
stds = torch.exp(self.log_stds).expand_as(means)
else:
raise ValueError(
f"Invalid std_parameterization: {self.std_parameterization}"
)
else:
assert self.std_parameterization == "fixed"
stds = self.fixed_std.expand_as(means)
# Clip standard deviations and scale with temperature
temperature = torch.tensor(temperature, device=self.device)
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
# Create distribution
if self.tanh_squash_distribution and not non_squash_distribution:
distribution = TanhMultivariateNormalDiag(
loc=means,
scale_diag=stds,
)
else:
distribution = torch.distributions.Normal(
loc=means,
scale=stds,
)
return distribution
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
if self.encoder is not None:
with torch.no_grad():
return self.encoder(observations, train=False)
return observations
class SACObservationEncoder(nn.Module): class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations.""" """Encode image and/or state vector observations.
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
"""
def __init__(self, config: SACConfig): def __init__(self, config: SACConfig):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__() super().__init__()
self.config = config self.config = config
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
for image_key in image_keys:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
return torch.stack(feat, dim=0).mean(0)
class LagrangeMultiplier(nn.Module):
def __init__(
self,
init_value: float = 1.0,
constraint_shape: Sequence[int] = (),
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
# Initialize the Lagrange multiplier as a parameter
self.lagrange = nn.Parameter(
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
)
self.to(self.device)
def forward(
self,
lhs: Optional[torch.Tensor] = None,
rhs: Optional[torch.Tensor] = None
) -> torch.Tensor:
# Get the multiplier value based on parameterization
multiplier = torch.nn.functional.softplus(self.lagrange)
# Return the raw multiplier if no constraint values provided
if lhs is None:
return multiplier
# Move inputs to device
lhs = lhs.to(self.device)
if rhs is not None:
rhs = rhs.to(self.device)
# Use the multiplier to compute the Lagrange penalty
if rhs is None:
rhs = torch.zeros_like(lhs, device=self.device)
diff = lhs - rhs
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
return multiplier * diff
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
# 1. The base distribution is a diagonal multivariate normal distribution
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
def __init__(
self,
loc: torch.Tensor,
scale_diag: torch.Tensor,
low: Optional[torch.Tensor] = None,
high: Optional[torch.Tensor] = None,
):
# Create base normal distribution
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
# Create list of transforms
transforms = []
# Add tanh transform
transforms.append(torch.distributions.transforms.TanhTransform())
# Add rescaling transform if bounds are provided
if low is not None and high is not None:
transforms.append(
torch.distributions.transforms.AffineTransform(
loc=(high + low) / 2,
scale=(high - low) / 2
)
)
# Initialize parent class
super().__init__(
base_distribution=base_distribution,
transforms=transforms
)
# Store parameters
self.loc = loc
self.scale_diag = scale_diag
self.low = low
self.high = high
def mode(self) -> torch.Tensor:
"""Get the mode of the transformed distribution"""
# The mode of a normal distribution is its mean
mode = self.loc
# Apply transforms
for transform in self.transforms:
mode = transform(mode)
return mode
def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
"""
Reparameterized sample from the distribution
"""
# Sample from base distribution
x = self.base_dist.rsample(sample_shape)
# Apply transforms
for transform in self.transforms:
x = transform(x)
return x
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Compute log probability of a value
Includes the log det jacobian for the transforms
"""
# Initialize log prob
log_prob = torch.zeros_like(value[..., 0])
# Inverse transforms to get back to normal distribution
q = value
for transform in reversed(self.transforms):
q = transform.inv(q)
log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
# Add base distribution log prob
log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
return log_prob
def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample from the distribution and compute log probability
"""
x = self.rsample(sample_shape)
log_prob = self.log_prob(x)
return x, log_prob
def entropy(self) -> torch.Tensor:
"""
Compute entropy of the distribution
"""
# Start with base distribution entropy
entropy = self.base_dist.entropy().sum(-1)
# Add log det jacobian for each transform
x = self.rsample()
for transform in self.transforms:
entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
x = transform(x)
return entropy
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
return critics.to(device)
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
can be more than 1 dimensions, generally different from *.
Returns:
A return value from the callable reshaped to (**, *).
"""
if image_tensor.ndim == 4:
return fn(image_tensor)
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))