From 214beec994132090ab076b42cdb92669c18732b3 Mon Sep 17 00:00:00 2001 From: KeWang Date: Tue, 17 Dec 2024 13:26:17 +0000 Subject: [PATCH] Port SAC WIP (#581) Co-authored-by: KeWang1017 --- .../common/policies/sac/configuration_sac.py | 39 ++ lerobot/common/policies/sac/modeling_sac.py | 508 +++++++++++++++++- 2 files changed, 541 insertions(+), 6 deletions(-) create mode 100644 lerobot/common/policies/sac/configuration_sac.py diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py new file mode 100644 index 00000000..441b3566 --- /dev/null +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass +class SACConfig: + discount = 0.99 + temperature_init = 1.0 + num_critics = 2 + critic_lr = 3e-4 + actor_lr = 3e-4 + critic_network_kwargs = { + "hidden_dims": [256, 256], + "activate_final": True, + } + actor_network_kwargs = { + "hidden_dims": [256, 256], + "activate_final": True, + } + policy_kwargs = { + "tanh_squash_distribution": True, + "std_parameterization": "uniform", + } diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index fb2e5542..9ea9449d 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -15,7 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO: (1) better device management + from collections import deque +from copy import deepcopy +from functools import partial import einops @@ -27,6 +31,10 @@ from torch import Tensor from huggingface_hub import PyTorchModelHubMixin from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.sac.configuration_sac import SACConfig +import numpy as np +from typing import Callable, Optional, Tuple, Sequence + + class SACPolicy( nn.Module, @@ -58,12 +66,27 @@ class SACPolicy( self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) + encoder = SACObservationEncoder(config) + # Define networks + critic_nets = [] + for _ in range(config.num_critics): + critic_net = Critic( + encoder=encoder, + network=MLP(**config.critic_network_kwargs) + ) + critic_nets.append(critic_net) - self.critic_ensemble = ... - self.critic_target = ... - self.actor_network = ... + self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics) + self.critic_target = deepcopy(self.critic_ensemble) - self.temperature = ... + self.actor_network = Policy( + encoder=encoder, + network=MLP(**config.actor_network_kwargs), + action_dim=config.output_shapes["action"][0], + **config.policy_kwargs + ) + + self.temperature = LagrangeMultiplier(init_value=config.temperature_init) def reset(self): """ @@ -178,10 +201,483 @@ class SACPolicy( #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()): # target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight) + +class MLP(nn.Module): + def __init__( + self, + config: SACConfig, + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: Optional[float] = None, + ): + super().__init__() + self.activate_final = config.activate_final + layers = [] + + for i, size in enumerate(config.network_hidden_dims): + layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size)) + + if i + 1 < len(config.network_hidden_dims) or activate_final: + if dropout_rate is not None and dropout_rate > 0: + layers.append(nn.Dropout(p=dropout_rate)) + layers.append(nn.LayerNorm(size)) + layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) + + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor: + # in training mode or not. TODO: find better way to do this + self.train(train) + return self.net(x) + + +class Critic(nn.Module): + def __init__( + self, + encoder: Optional[nn.Module], + network: nn.Module, + init_final: Optional[float] = None, + activate_final: bool = False, + device: str = "cuda" + ): + super().__init__() + self.device = torch.device(device) + self.encoder = encoder + self.network = network + self.init_final = init_final + self.activate_final = activate_final + + # Output layer + if init_final is not None: + if self.activate_final: + self.output_layer = nn.Linear(network.net[-3].out_features, 1) + else: + self.output_layer = nn.Linear(network.net[-2].out_features, 1) + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + if self.activate_final: + self.output_layer = nn.Linear(network.net[-3].out_features, 1) + else: + self.output_layer = nn.Linear(network.net[-2].out_features, 1) + orthogonal_init()(self.output_layer.weight) + + self.to(self.device) + + def forward( + self, + observations: torch.Tensor, + actions: torch.Tensor, + train: bool = False + ) -> torch.Tensor: + self.train(train) + + observations = observations.to(self.device) + actions = actions.to(self.device) + + if self.encoder is not None: + obs_enc = self.encoder(observations) + else: + obs_enc = observations + + inputs = torch.cat([obs_enc, actions], dim=-1) + x = self.network(inputs) + value = self.output_layer(x) + return value.squeeze(-1) + + def q_value_ensemble( + self, + observations: torch.Tensor, + actions: torch.Tensor, + train: bool = False + ) -> torch.Tensor: + observations = observations.to(self.device) + actions = actions.to(self.device) + + if len(actions.shape) == 3: # [batch_size, num_actions, action_dim] + batch_size, num_actions = actions.shape[:2] + obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1) + obs_flat = obs_expanded.reshape(-1, observations.shape[-1]) + actions_flat = actions.reshape(-1, actions.shape[-1]) + q_values = self(obs_flat, actions_flat, train) + return q_values.reshape(batch_size, num_actions) + else: + return self(observations, actions, train) + + +class Policy(nn.Module): + def __init__( + self, + encoder: Optional[nn.Module], + network: nn.Module, + action_dim: int, + std_parameterization: str = "exp", + std_min: float = 1e-5, + std_max: float = 10.0, + tanh_squash_distribution: bool = False, + fixed_std: Optional[torch.Tensor] = None, + init_final: Optional[float] = None, + activate_final: bool = False, + device: str = "cuda" + ): + super().__init__() + self.device = torch.device(device) + self.encoder = encoder + self.network = network + self.action_dim = action_dim + self.std_parameterization = std_parameterization + self.std_min = std_min + self.std_max = std_max + self.tanh_squash_distribution = tanh_squash_distribution + self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None + self.activate_final = activate_final + + # Mean layer + if self.activate_final: + self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim) + else: + self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.mean_layer.weight, -init_final, init_final) + nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.mean_layer.weight) + + # Standard deviation layer or parameter + if fixed_std is None: + if std_parameterization == "uniform": + self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device)) + else: + if self.activate_final: + self.std_layer = nn.Linear(network.net[-3].out_features, action_dim) + else: + self.std_layer = nn.Linear(network.net[-2].out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.std_layer.weight, -init_final, init_final) + nn.init.uniform_(self.std_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.std_layer.weight) + + self.to(self.device) + + def forward( + self, + observations: torch.Tensor, + temperature: float = 1.0, + train: bool = False, + non_squash_distribution: bool = False + ) -> torch.distributions.Distribution: + self.train(train) + + # Encode observations if encoder exists + if self.encoder is not None: + with torch.set_grad_enabled(train): + obs_enc = self.encoder(observations, train=train) + else: + obs_enc = observations + # Get network outputs + outputs = self.network(obs_enc) + means = self.mean_layer(outputs) + + # Compute standard deviations + if self.fixed_std is None: + if self.std_parameterization == "exp": + log_stds = self.std_layer(outputs) + stds = torch.exp(log_stds) + elif self.std_parameterization == "softplus": + stds = torch.nn.functional.softplus(self.std_layer(outputs)) + elif self.std_parameterization == "uniform": + stds = torch.exp(self.log_stds).expand_as(means) + else: + raise ValueError( + f"Invalid std_parameterization: {self.std_parameterization}" + ) + else: + assert self.std_parameterization == "fixed" + stds = self.fixed_std.expand_as(means) + + # Clip standard deviations and scale with temperature + temperature = torch.tensor(temperature, device=self.device) + stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature) + + # Create distribution + if self.tanh_squash_distribution and not non_squash_distribution: + distribution = TanhMultivariateNormalDiag( + loc=means, + scale_diag=stds, + ) + else: + distribution = torch.distributions.Normal( + loc=means, + scale=stds, + ) + + return distribution + + def get_features(self, observations: torch.Tensor) -> torch.Tensor: + """Get encoded features from observations""" + observations = observations.to(self.device) + if self.encoder is not None: + with torch.no_grad(): + return self.encoder(observations, train=False) + return observations + + class SACObservationEncoder(nn.Module): - """Encode image and/or state vector observations.""" + """Encode image and/or state vector observations. + TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders. + """ def __init__(self, config: SACConfig): - + """ + Creates encoders for pixel and/or state modalities. + """ super().__init__() self.config = config + + if "observation.image" in config.input_shapes: + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 + ), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + ) + dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) + with torch.inference_mode(): + out_shape = self.image_enc_layers(dummy_batch).shape[1:] + self.image_enc_layers.extend( + nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), + ) + ) + if "observation.state" in config.input_shapes: + self.state_enc_layers = nn.Sequential( + nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), + ) + if "observation.environment_state" in config.input_shapes: + self.env_state_enc_layers = nn.Sequential( + nn.Linear( + config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim + ), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), + ) + + def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + # Concatenate all images along the channel dimension. + image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")] + for image_key in image_keys: + feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])) + if "observation.environment_state" in self.config.input_shapes: + feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + if "observation.state" in self.config.input_shapes: + feat.append(self.state_enc_layers(obs_dict["observation.state"])) + return torch.stack(feat, dim=0).mean(0) + + +class LagrangeMultiplier(nn.Module): + def __init__( + self, + init_value: float = 1.0, + constraint_shape: Sequence[int] = (), + device: str = "cuda" + ): + super().__init__() + self.device = torch.device(device) + init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1) + + # Initialize the Lagrange multiplier as a parameter + self.lagrange = nn.Parameter( + torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device) + ) + + self.to(self.device) + + def forward( + self, + lhs: Optional[torch.Tensor] = None, + rhs: Optional[torch.Tensor] = None + ) -> torch.Tensor: + # Get the multiplier value based on parameterization + multiplier = torch.nn.functional.softplus(self.lagrange) + + # Return the raw multiplier if no constraint values provided + if lhs is None: + return multiplier + + # Move inputs to device + lhs = lhs.to(self.device) + if rhs is not None: + rhs = rhs.to(self.device) + + # Use the multiplier to compute the Lagrange penalty + if rhs is None: + rhs = torch.zeros_like(lhs, device=self.device) + + diff = lhs - rhs + + assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}" + + return multiplier * diff + + +# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where: +# 1. The base distribution is a diagonal multivariate normal distribution +# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1 +# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation +# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces +class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution): + def __init__( + self, + loc: torch.Tensor, + scale_diag: torch.Tensor, + low: Optional[torch.Tensor] = None, + high: Optional[torch.Tensor] = None, + ): + # Create base normal distribution + base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag) + + # Create list of transforms + transforms = [] + + # Add tanh transform + transforms.append(torch.distributions.transforms.TanhTransform()) + + # Add rescaling transform if bounds are provided + if low is not None and high is not None: + transforms.append( + torch.distributions.transforms.AffineTransform( + loc=(high + low) / 2, + scale=(high - low) / 2 + ) + ) + + # Initialize parent class + super().__init__( + base_distribution=base_distribution, + transforms=transforms + ) + + # Store parameters + self.loc = loc + self.scale_diag = scale_diag + self.low = low + self.high = high + + def mode(self) -> torch.Tensor: + """Get the mode of the transformed distribution""" + # The mode of a normal distribution is its mean + mode = self.loc + + # Apply transforms + for transform in self.transforms: + mode = transform(mode) + + return mode + + def rsample(self, sample_shape=torch.Size()) -> torch.Tensor: + """ + Reparameterized sample from the distribution + """ + # Sample from base distribution + x = self.base_dist.rsample(sample_shape) + + # Apply transforms + for transform in self.transforms: + x = transform(x) + + return x + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + """ + Compute log probability of a value + Includes the log det jacobian for the transforms + """ + # Initialize log prob + log_prob = torch.zeros_like(value[..., 0]) + + # Inverse transforms to get back to normal distribution + q = value + for transform in reversed(self.transforms): + q = transform.inv(q) + log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q)) + + # Add base distribution log prob + log_prob = log_prob + self.base_dist.log_prob(q).sum(-1) + + return log_prob + + def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Sample from the distribution and compute log probability + """ + x = self.rsample(sample_shape) + log_prob = self.log_prob(x) + return x, log_prob + + def entropy(self) -> torch.Tensor: + """ + Compute entropy of the distribution + """ + # Start with base distribution entropy + entropy = self.base_dist.entropy().sum(-1) + + # Add log det jacobian for each transform + x = self.rsample() + for transform in self.transforms: + entropy = entropy + transform.log_abs_det_jacobian(x, transform(x)) + x = transform(x) + + return entropy + + +def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList: + """Creates an ensemble of critic networks""" + critics = nn.ModuleList([critic_class() for _ in range(num_critics)]) + return critics.to(device) + + +def orthogonal_init(): + return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) + + +# borrowed from tdmpc +def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: + """Helper to temporarily flatten extra dims at the start of the image tensor. + + Args: + fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return + (B, *), where * is any number of dimensions. + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and + can be more than 1 dimensions, generally different from *. + Returns: + A return value from the callable reshaped to (**, *). + """ + if image_tensor.ndim == 4: + return fn(image_tensor) + start_dims = image_tensor.shape[:-3] + inp = torch.flatten(image_tensor, end_dim=-4) + flat_out = fn(inp) + return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) +