# License: see [LICENSE, LICENSES/rsl_rl/LICENSE] import torch import torch.nn as nn from params_proto import PrefixProto from torch.distributions import Normal class AC_Args(PrefixProto, cli=False): # policy init_noise_std = 1.0 actor_hidden_dims = [512, 256, 128] critic_hidden_dims = [512, 256, 128] activation = 'elu' # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid adaptation_module_branch_hidden_dims = [[256, 32]] env_factor_encoder_branch_input_dims = [18] env_factor_encoder_branch_latent_dims = [18] env_factor_encoder_branch_hidden_dims = [[256, 128]] class ActorCritic(nn.Module): is_recurrent = False def __init__(self, num_obs, num_privileged_obs, num_obs_history, num_actions, **kwargs): if kwargs: print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str( [key for key in kwargs.keys()])) super().__init__() activation = get_activation(AC_Args.activation) for i, (branch_input_dim, branch_hidden_dims, branch_latent_dim) in enumerate( zip(AC_Args.env_factor_encoder_branch_input_dims, AC_Args.env_factor_encoder_branch_hidden_dims, AC_Args.env_factor_encoder_branch_latent_dims)): # Env factor encoder env_factor_encoder_layers = [] env_factor_encoder_layers.append(nn.Linear(branch_input_dim, branch_hidden_dims[0])) env_factor_encoder_layers.append(activation) for l in range(len(branch_hidden_dims)): if l == len(branch_hidden_dims) - 1: env_factor_encoder_layers.append( nn.Linear(branch_hidden_dims[l], branch_latent_dim)) else: env_factor_encoder_layers.append( nn.Linear(branch_hidden_dims[l], branch_hidden_dims[l + 1])) env_factor_encoder_layers.append(activation) self.env_factor_encoder = nn.Sequential(*env_factor_encoder_layers) self.add_module(f"encoder", self.env_factor_encoder) # Adaptation module for i, (branch_hidden_dims, branch_latent_dim) in enumerate(zip(AC_Args.adaptation_module_branch_hidden_dims, AC_Args.env_factor_encoder_branch_latent_dims)): adaptation_module_layers = [] adaptation_module_layers.append(nn.Linear(num_obs_history, branch_hidden_dims[0])) adaptation_module_layers.append(activation) for l in range(len(branch_hidden_dims)): if l == len(branch_hidden_dims) - 1: adaptation_module_layers.append( nn.Linear(branch_hidden_dims[l], branch_latent_dim)) else: adaptation_module_layers.append( nn.Linear(branch_hidden_dims[l], branch_hidden_dims[l + 1])) adaptation_module_layers.append(activation) self.adaptation_module = nn.Sequential(*adaptation_module_layers) self.add_module(f"adaptation_module", self.adaptation_module) total_latent_dim = int(torch.sum(torch.Tensor(AC_Args.env_factor_encoder_branch_latent_dims))) # Policy actor_layers = [] actor_layers.append(nn.Linear(total_latent_dim + num_obs, AC_Args.actor_hidden_dims[0])) actor_layers.append(activation) for l in range(len(AC_Args.actor_hidden_dims)): if l == len(AC_Args.actor_hidden_dims) - 1: actor_layers.append(nn.Linear(AC_Args.actor_hidden_dims[l], num_actions)) else: actor_layers.append(nn.Linear(AC_Args.actor_hidden_dims[l], AC_Args.actor_hidden_dims[l + 1])) actor_layers.append(activation) self.actor_body = nn.Sequential(*actor_layers) # Value function critic_layers = [] critic_layers.append(nn.Linear(total_latent_dim + num_obs, AC_Args.critic_hidden_dims[0])) critic_layers.append(activation) for l in range(len(AC_Args.critic_hidden_dims)): if l == len(AC_Args.critic_hidden_dims) - 1: critic_layers.append(nn.Linear(AC_Args.critic_hidden_dims[l], 1)) else: critic_layers.append(nn.Linear(AC_Args.critic_hidden_dims[l], AC_Args.critic_hidden_dims[l + 1])) critic_layers.append(activation) self.critic_body = nn.Sequential(*critic_layers) print(f"Environment Factor Encoder: {self.env_factor_encoder}") print(f"Adaptation Module: {self.adaptation_module}") print(f"Actor MLP: {self.actor_body}") print(f"Critic MLP: {self.critic_body}") # Action noise self.std = nn.Parameter(AC_Args.init_noise_std * torch.ones(num_actions)) self.distribution = None # disable args validation for speedup Normal.set_default_validate_args = False @staticmethod # not used at the moment def init_weights(sequential, scales): [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))] def reset(self, dones=None): pass def forward(self): raise NotImplementedError @property def action_mean(self): return self.distribution.mean @property def action_std(self): return self.distribution.stddev @property def entropy(self): return self.distribution.entropy().sum(dim=-1) def update_distribution(self, observations, privileged_observations): latent = self.env_factor_encoder(privileged_observations) mean = self.actor_body(torch.cat((observations, latent), dim=-1)) self.distribution = Normal(mean, mean * 0. + self.std) def act(self, observations, privileged_observations, **kwargs): self.update_distribution(observations, privileged_observations) return self.distribution.sample() def get_actions_log_prob(self, actions): return self.distribution.log_prob(actions).sum(dim=-1) def act_expert(self, ob, policy_info={}): return self.act_teacher(ob["obs"], ob["privileged_obs"]) def act_inference(self, ob, policy_info={}): if ob["privileged_obs"] is not None: gt_latent = self.env_factor_encoder(ob["privileged_obs"]) policy_info["gt_latents"] = gt_latent.detach().cpu().numpy() return self.act_student(ob["obs"], ob["obs_history"]) def act_student(self, observations, observation_history, policy_info={}): latent = self.adaptation_module(observation_history) actions_mean = self.actor_body(torch.cat((observations, latent), dim=-1)) policy_info["latents"] = latent.detach().cpu().numpy() return actions_mean def act_teacher(self, observations, privileged_info, policy_info={}): latent = self.env_factor_encoder(privileged_info) actions_mean = self.actor_body(torch.cat((observations, latent), dim=-1)) policy_info["latents"] = latent.detach().cpu().numpy() return actions_mean def evaluate(self, critic_observations, privileged_observations, **kwargs): latent = self.env_factor_encoder(privileged_observations) value = self.critic_body(torch.cat((critic_observations, latent), dim=-1)) return value def get_activation(act_name): if act_name == "elu": return nn.ELU() elif act_name == "selu": return nn.SELU() elif act_name == "relu": return nn.ReLU() elif act_name == "crelu": return nn.ReLU() elif act_name == "lrelu": return nn.LeakyReLU() elif act_name == "tanh": return nn.Tanh() elif act_name == "sigmoid": return nn.Sigmoid() else: print("invalid activation function!") return None