from abc import abstractmethod
import torch
from typing import Any, Callable, Dict, List, Tuple, Union

from rsl_rl.algorithms.agent import Agent
from rsl_rl.env.vec_env import VecEnv
from rsl_rl.modules.network import Network
from rsl_rl.storage.storage import Dataset
from rsl_rl.utils.utils import environment_dimensions
from rsl_rl.utils.utils import squeeze_preserve_batch


class AbstractActorCritic(Agent):
    _alg_features = dict(recurrent=False)

    def __init__(
        self,
        env: VecEnv,
        actor_activations: List[str] = ["relu", "relu", "relu", "linear"],
        actor_hidden_dims: List[int] = [256, 256, 256],
        actor_init_gain: float = 0.5,
        actor_input_normalization: bool = False,
        actor_recurrent_layers: int = 1,
        actor_recurrent_module: str = Network.recurrent_module_lstm,
        actor_recurrent_tf_context_length: int = 64,
        actor_recurrent_tf_head_count: int = 8,
        actor_shared_dims: int = None,
        batch_count: int = 1,
        batch_size: int = 1,
        critic_activations: List[str] = ["relu", "relu", "relu", "linear"],
        critic_hidden_dims: List[int] = [256, 256, 256],
        critic_init_gain: float = 0.5,
        critic_input_normalization: bool = False,
        critic_recurrent_layers: int = 1,
        critic_recurrent_module: str = Network.recurrent_module_lstm,
        critic_recurrent_tf_context_length: int = 64,
        critic_recurrent_tf_head_count: int = 8,
        critic_shared_dims: int = None,
        polyak: float = 0.995,
        recurrent: bool = False,
        return_steps: int = 1,
        _actor_input_size_delta: int = 0,
        _critic_input_size_delta: int = 0,
        **kwargs,
    ):
"""Creates an actor critic agent.
|
|
|
|
Args:
|
|
env (VecEnv): A vectorized environment.
|
|
actor_activations (List[str]): A list of activation functions for the actor network.
|
|
actor_hidden_dims (List[str]): A list of layer sizes for the actor network.
|
|
actor_init_gain (float): Network initialization gain for actor.
|
|
actor_input_normalization (bool): Whether to empirically normalize inputs to the actor network.
|
|
actor_recurrent_layers (int): The number of recurrent layers to use for the actor network.
|
|
actor_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
|
|
actor_shared_dims (int): The number of dimensions to share for an actor with multiple heads.
|
|
batch_count (int): The number of batches to process per update step.
|
|
batch_size (int): The size of each batch to process during the update step.
|
|
critic_activations (List[str]): A list of activation functions for the critic network.
|
|
critic_hidden_dims: (List[str]): A list of layer sizes for the critic network.
|
|
critic_init_gain (float): Network initialization gain for critic.
|
|
critic_input_normalization (bool): Whether to empirically normalize inputs to the critic network.
|
|
critic_recurrent_layers (int): The number of recurrent layers to use for the critic network.
|
|
critic_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
|
|
critic_shared_dims (int): The number of dimensions to share for a critic with multiple heads.
|
|
polyak (float): The actor-critic target network polyak factor.
|
|
recurrent (bool): Whether to use recurrent actor and critic networks.
|
|
recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
|
|
recurrent_tf_context_length (int): The context length of the Transformer.
|
|
recurrent_tf_head_count (int): The head count of the Transformer.
|
|
return_steps (float): The number of steps over which to compute the returns (n-step return).
|
|
_actor_input_size_delta (int): The number of additional dimensions to add to the actor input.
|
|
_critic_input_size_delta (int): The number of additional dimensions to add to the critic input.
|
|
"""
|
|
        assert (
            self._alg_features["recurrent"] or not recurrent
        ), f"{self.__class__.__name__} does not support recurrent networks."

        super().__init__(env, **kwargs)

        self.actor: torch.nn.Module = None
        self.actor_optimizer: torch.optim.Optimizer = None
        self.critic_optimizer: torch.optim.Optimizer = None
        self.critic: torch.nn.Module = None

        self._batch_size = batch_size
        self._batch_count = batch_count
        self._polyak_factor = polyak
        self._return_steps = return_steps
        self._recurrent = recurrent

        self._register_serializable(
            "_batch_size", "_batch_count", "_discount_factor", "_polyak_factor", "_return_steps"
        )
        dimensions = environment_dimensions(self.env)
        try:
            actor_input_size = dimensions["actor_observations"]
            critic_input_size = dimensions["critic_observations"]
        except KeyError:
            actor_input_size = dimensions["observations"]
            critic_input_size = dimensions["observations"]
        self._actor_input_size = actor_input_size + _actor_input_size_delta
        # The critic evaluates observation-action pairs (see _critic_input), so its input size includes
        # the action dimensions.
        self._critic_input_size = critic_input_size + self._action_size + _critic_input_size_delta

        self._register_actor_network_kwargs(
            activations=actor_activations,
            hidden_dims=actor_hidden_dims,
            init_gain=actor_init_gain,
            input_normalization=actor_input_normalization,
            recurrent=recurrent,
            recurrent_layers=actor_recurrent_layers,
            recurrent_module=actor_recurrent_module,
            recurrent_tf_context_length=actor_recurrent_tf_context_length,
            recurrent_tf_head_count=actor_recurrent_tf_head_count,
        )

        if actor_shared_dims is not None:
            self._register_actor_network_kwargs(shared_dims=actor_shared_dims)

        self._register_critic_network_kwargs(
            activations=critic_activations,
            hidden_dims=critic_hidden_dims,
            init_gain=critic_init_gain,
            input_normalization=critic_input_normalization,
            recurrent=recurrent,
            recurrent_layers=critic_recurrent_layers,
            recurrent_module=critic_recurrent_module,
            recurrent_tf_context_length=critic_recurrent_tf_context_length,
            recurrent_tf_head_count=critic_recurrent_tf_head_count,
        )

        if critic_shared_dims is not None:
            self._register_critic_network_kwargs(shared_dims=critic_shared_dims)

        self._register_serializable(
            "_actor_input_size", "_actor_network_kwargs", "_critic_input_size", "_critic_network_kwargs"
        )

        # For computing n-step returns using prior transitions.
        self._stored_dataset = []

    def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]:
        self.eval_mode()

        class ONNXActor(torch.nn.Module):
            def __init__(self, model: torch.nn.Module):
                super().__init__()

                self.model = model

            def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor] = None):
                if hidden_state is None:
                    return self.model(x)

                data = self.model(x, hidden_state=hidden_state)
                hidden_state = self.model.last_hidden_state

                return data, hidden_state

        model = ONNXActor(self.actor)
        kwargs = dict(
            export_params=True,
            opset_version=11,
            verbose=True,
            dynamic_axes={},
        )

        kwargs["input_names"] = ["observations"]
        kwargs["output_names"] = ["actions"]

        args = torch.zeros(1, self._actor_input_size)

        if self.actor.recurrent:
            hidden_state = (
                torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size),
                torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size),
            )
            args = (args, {"hidden_state": hidden_state})

        return model, args, kwargs
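
    # Example (illustrative, not part of the class): the tuple returned by export_onnx() is meant to be
    # forwarded to torch.onnx.export. The file name "policy.onnx" below is an arbitrary choice.
    #
    #     model, args, kwargs = agent.export_onnx()
    #     torch.onnx.export(model, args, "policy.onnx", **kwargs)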

    def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> Tuple[torch.Tensor, Dict]:
        actions, data = super().draw_random_actions(obs, env_info)

        actor_obs, critic_obs = self._process_observations(obs, env_info)
        data.update({"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()})

        return actions, data

    def get_inference_policy(self, device=None) -> Callable:
        self.to(device)
        self.eval_mode()

        if self.actor.recurrent:
            self.actor.reset_full_hidden_state(batch_size=self.env.num_envs)

        if self.critic.recurrent:
            self.critic.reset_full_hidden_state(batch_size=self.env.num_envs)

        def policy(obs, env_info=None):
            with torch.inference_mode():
                obs, _ = self._process_observations(obs, env_info)

                actions = self._process_actions(self.actor.forward(obs))

            return actions

        return policy
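
    # Example (illustrative): typical rollout usage of the returned callable. `agent` is assumed to be an
    # instantiated subclass of this agent and `obs` an observation batch from its vectorized environment.
    #
    #     policy = agent.get_inference_policy(device=agent.device)
    #     actions = policy(obs)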

    def process_transition(
        self,
        observations: torch.Tensor,
        environment_info: Dict[str, Any],
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_observations: torch.Tensor,
        next_environment_info: Dict[str, Any],
        dones: torch.Tensor,
        data: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        if "actor_observations" in data and "critic_observations" in data:
            actor_obs, critic_obs = data["actor_observations"], data["critic_observations"]
        else:
            actor_obs, critic_obs = self._process_observations(observations, environment_info)

        if "next_actor_observations" in data and "next_critic_observations" in data:
            next_actor_obs, next_critic_obs = data["next_actor_observations"], data["next_critic_observations"]
        else:
            next_actor_obs, next_critic_obs = self._process_observations(next_observations, next_environment_info)

        transition = {
            "actions": actions,
            "actor_observations": actor_obs,
            "critic_observations": critic_obs,
            "dones": dones,
            "next_actor_observations": next_actor_obs,
            "next_critic_observations": next_critic_obs,
            "rewards": squeeze_preserve_batch(rewards),
            "timeouts": self._extract_timeouts(next_environment_info),
        }
        transition.update(data)

        for key, value in transition.items():
            transition[key] = value.detach().clone()

        return transition

    @property
    def recurrent(self) -> bool:
        return self._recurrent

    def register_terminations(self, terminations: torch.Tensor) -> None:
        pass

    @abstractmethod
    def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
        with torch.inference_mode():
            self.storage.append(self._process_dataset(dataset))
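
    # Sketch (hypothetical, for orientation only): concrete subclasses are expected to override update(),
    # first calling super().update(dataset) to post-process and store the transitions, then processing
    # self._batch_count batches of size self._batch_size from self.storage to step the actor and critic
    # optimizers. The body below is pseudocode, not an implementation of this library's API:
    #
    #     def update(self, dataset):
    #         super().update(dataset)
    #         for batch in ...:  # draw self._batch_count mini-batches from self.storage
    #             ...            # compute critic / actor losses; step self.critic_optimizer and self.actor_optimizer
    #         return {"actor_loss": ..., "critic_loss": ...}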

    def _critic_input(self, observations, actions) -> torch.Tensor:
        """Combines observations and actions into a tensor that can be fed into the critic network.

        Args:
            observations (torch.Tensor): The critic observations.
            actions (torch.Tensor): The actions computed by the actor.
        Returns:
            A torch.Tensor of inputs for the critic network.
        """
        return torch.cat((observations, actions), dim=-1)
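
    # For example, observations of shape (N, obs_dim) and actions of shape (N, act_dim) are concatenated
    # into a critic input of shape (N, obs_dim + act_dim), matching self._critic_input_size (plus any
    # _critic_input_size_delta).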

    def _extract_timeouts(self, next_environment_info):
        """Extracts timeout information from the transition next state information dictionary.

        Args:
            next_environment_info (Dict[str, Any]): The transition next state information dictionary.
        Returns:
            A torch.Tensor vector of actor timeouts.
        """
        if "time_outs" not in next_environment_info:
            return torch.zeros(self.env.num_envs, device=self.device)

        timeouts = squeeze_preserve_batch(next_environment_info["time_outs"].to(self.device))

        return timeouts

    def _process_dataset(self, dataset: Dataset) -> Dataset:
        """Processes a dataset before it is added to the replay memory.

        Handles n-step returns and timeouts.

        TODO: This function seems to be a bottleneck in the training pipeline - speed it up!

        Args:
            dataset (Dataset): The dataset to process.
        Returns:
            A Dataset object containing the processed data.
        """
        assert len(dataset) >= self._return_steps

        dataset = self._stored_dataset + dataset
        length = len(dataset) - self._return_steps + 1
        self._stored_dataset = dataset[length:]

        for idx in range(len(dataset) - self._return_steps + 1):
            dead_idx = torch.zeros_like(dataset[idx]["dones"])
            rewards = torch.zeros_like(dataset[idx]["rewards"])

            for k in range(self._return_steps):
                data = dataset[idx + k]
                alive_idx = (dead_idx == 0).nonzero()
                critic_predictions = self.critic.forward(
                    self._critic_input(
                        data["critic_observations"].clone().to(self.device),
                        data["actions"].clone().to(self.device),
                    )
                )
                rewards[alive_idx] += self._discount_factor**k * data["rewards"][alive_idx]
                rewards[alive_idx] += (
                    self._discount_factor ** (k + 1) * data["timeouts"][alive_idx] * critic_predictions[alive_idx]
                )
                dead_idx += data["dones"]
                dead_idx += data["timeouts"]

            dataset[idx]["rewards"] = rewards

        return dataset[:length]

    def _process_observations(
        self, observations: torch.Tensor, environment_info: Dict[str, Any] = None
    ) -> Tuple[torch.Tensor, ...]:
        """Processes observations returned by the environment to extract actor and critic observations.

        Args:
            observations (torch.Tensor): The regular environment observations.
            environment_info (Dict[str, Any]): A dictionary of additional environment information.
        Returns:
            A tuple containing two torch.Tensors with actor and critic observations, respectively.
        """
        try:
            critic_obs = environment_info["observations"]["critic"]
        except (KeyError, TypeError):
            critic_obs = observations

        actor_obs, critic_obs = observations.to(self.device), critic_obs.to(self.device)

        return actor_obs, critic_obs

    def _register_actor_network_kwargs(self, **kwargs) -> None:
        """Configures the actor network in child classes before calling super().__init__()."""
        if not hasattr(self, "_actor_network_kwargs"):
            self._actor_network_kwargs = dict()

        self._actor_network_kwargs.update(**kwargs)

    def _register_critic_network_kwargs(self, **kwargs) -> None:
        """Configures the critic network in child classes before calling super().__init__()."""
        if not hasattr(self, "_critic_network_kwargs"):
            self._critic_network_kwargs = dict()

        self._critic_network_kwargs.update(**kwargs)

    def _update_target(self, online: torch.nn.Module, target: torch.nn.Module) -> None:
        """Updates the target network parameters using the polyak factor.

        Args:
            online (torch.nn.Module): The online network.
            target (torch.nn.Module): The target network.
        """
        for op, tp in zip(online.parameters(), target.parameters()):
            tp.data.copy_((1.0 - self._polyak_factor) * op.data + self._polyak_factor * tp.data)
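
    # Illustration (comment only): with polyak = 0.995 the update above computes, per parameter,
    #     target <- 0.995 * target + 0.005 * online,
    # i.e. the target network tracks an exponential moving average of the online network, as is standard
    # for target networks in off-policy actor-critic methods.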
|