velocity planner added

Jerry Xu 2024-05-24 11:27:35 -04:00
parent 4f0f062efa
commit 44e8cbf692
3 changed files with 54 additions and 25 deletions
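
In outline, the commit adds a small MLP "velocity planner" that is trained jointly with the PPO policy: the planner reads the observation with the three command entries (columns 9:12) removed, predicts a forward velocity, clips it to the configured lin_vel_x range, and writes it into observation column 9 before the recurrent actor acts; the planner gets its own optimizer, whose state is checkpointed alongside the policy. Below is a minimal standalone sketch of that data flow, not the repo's code; the observation size is hypothetical and the column-9/9:12 command layout is a reading of the diff, not something it states.

# Standalone sketch of the velocity-planner data flow added by this commit.
# Assumptions: `num_obs` is hypothetical; columns 9:12 of the observation hold
# the velocity command, with column 9 being the forward (x) command.
import torch
import torch.nn as nn

num_obs = 48                      # hypothetical observation size
lin_vel_x = (-1.0, 1.0)           # stand-in for the configured command range

# The planner sees the observation with the 3 command entries removed (hence num_obs - 3),
# mirroring the MLP the runner builds in the last file of this diff.
velocity_planner = nn.Sequential(
    nn.Linear(num_obs - 3, 256), nn.ELU(),
    nn.Linear(256, 128), nn.ELU(),
    nn.Linear(128, 1),
)

obs = torch.randn(16, num_obs)    # a batch of observations
vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
velocity = torch.clip(velocity_planner(vel_obs), lin_vel_x[0], lin_vel_x[1])
obs[:, 9] = velocity.squeeze(-1)  # planner output replaces the commanded x velocity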

View File

@@ -32,6 +32,7 @@ from collections import defaultdict
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import copy

 from rsl_rl.modules import ActorCritic
 from rsl_rl.storage import RolloutStorage
@@ -40,6 +41,7 @@ class PPO:
     actor_critic: ActorCritic
     def __init__(self,
                  actor_critic,
+                 velocity_planner,
                  num_learning_epochs=1,
                  num_mini_batches=1,
                  clip_param=0.2,
@@ -55,6 +57,7 @@ class PPO:
                  schedule="fixed",
                  desired_kl=0.01,
                  device='cpu',
+                 **kwargs
                  ):

         self.device = device
@@ -70,6 +73,11 @@ class PPO:
         self.optimizer = getattr(optim, optimizer_class_name)(self.actor_critic.parameters(), lr=learning_rate)
         self.transition = RolloutStorage.Transition()

+        # Velocity Planner
+        self.velocity_planner = velocity_planner
+        self.velocity_optimizer = getattr(optim, optimizer_class_name)(self.velocity_planner.parameters(), lr=learning_rate)
+        self.lin_vel_x = kwargs.get('lin_vel_x', None)
+
         # PPO parameters
         self.clip_param = clip_param
         self.num_learning_epochs = num_learning_epochs
@@ -98,7 +106,12 @@
         if self.actor_critic.is_recurrent:
             self.transition.hidden_states = self.actor_critic.get_hidden_states()
         # Compute the actions and values
-        self.transition.actions = self.actor_critic.act(obs).detach()
+        vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
+        velocity = self.velocity_planner(vel_obs)
+        if self.lin_vel_x is not None:
+            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.transition.actions = self.actor_critic.act(obs, velocity=velocity)[0].detach()
         self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
         self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.actor_critic.action_mean.detach()
@@ -106,7 +119,7 @@
         # need to record obs and critic_obs before env.step()
         self.transition.observations = obs
         self.transition.critic_observations = critic_obs
-        return self.transition.actions
+        return self.transition.actions, velocity.squeeze().detach()

     def process_env_step(self, rewards, dones, infos):
         self.transition.rewards = rewards.clone()
@@ -146,9 +159,11 @@
             # Gradient step
             self.optimizer.zero_grad()
+            self.velocity_optimizer.zero_grad()
             loss.backward()
             nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
             self.optimizer.step()
+            self.velocity_optimizer.step()

         num_updates = self.num_learning_epochs * self.num_mini_batches
         for k in mean_losses.keys():
@@ -162,9 +177,16 @@
         return mean_losses, average_stats

     def compute_losses(self, minibatch):
-        self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0])
+        obs = copy.deepcopy(minibatch.obs)
+        # print(obs.shape)
+        vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
+        # print(vel_obs.shape)
+        velocity = self.velocity_planner(vel_obs)
+        if self.lin_vel_x is not None:
+            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity)
         actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
-        value_batch = self.actor_critic.evaluate(minibatch.critic_obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
+        value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
         mu_batch = self.actor_critic.action_mean
         sigma_batch = self.actor_critic.action_std
         try:
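
Worth noting in the compute_losses hunk above: the minibatch observations are deep-copied before the planner's velocity is written into them. Since ActorCriticRecurrent.act (next file) overwrites column 9 of whatever tensor it is handed, the copy guards against mutating the tensors the minibatch still references. A small sketch of the in-place-write hazard the copy avoids, under the same column-9 assumption as above:

import copy
import torch

stored_obs = torch.zeros(4, 48)      # stands in for observations kept elsewhere (e.g. rollout storage)
velocity = torch.full((4, 1), 0.5)   # stands in for the planner output

obs = stored_obs                     # aliasing, no copy
obs[:, 9] = velocity.squeeze(-1)     # in-place write also changes stored_obs
assert stored_obs[0, 9] == 0.5

stored_obs = torch.zeros(4, 48)
obs = copy.deepcopy(stored_obs)      # what compute_losses does (obs.clone() would work too)
obs[:, 9] = velocity.squeeze(-1)
assert stored_obs[0, 9] == 0.0       # the original tensor is left untouched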

View File

@@ -62,15 +62,6 @@ class ActorCriticRecurrent(ActorCritic):
         activation = get_activation(activation)

-        self.velocity_planner = nn.Sequential(
-            nn.Linear(num_actor_obs-3, 256),
-            nn.ELU(),
-            nn.Linear(256, 128),
-            nn.ELU(),
-            nn.Linear(128, 1),
-            nn.Tanh()
-        )
-
         self.lin_vel_x = kwargs["lin_vel_x"]

         self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
         self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
@@ -82,16 +73,19 @@ class ActorCriticRecurrent(ActorCritic):
         self.memory_a.reset(dones)
         self.memory_c.reset(dones)

-    def act(self, observations, masks=None, hidden_states=None):
-        vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
-        velocity = self.velocity_planner(vel_obs)
-        velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        self.velocity = velocity
-        observations[:, 9] = velocity
+    def act(self, observations, masks=None, hidden_states=None, velocity=None):
+        if velocity is not None:
+            observations[..., 9] = velocity.squeeze()
+        # vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
+        # velocity = self.velocity_planner(vel_obs)
+        # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        # observations[:, 9] = velocity
         input_a = self.memory_a(observations, masks, hidden_states)
-        return super().act(input_a.squeeze(0))
+        return super().act(input_a.squeeze(0)), velocity.squeeze().detach()

-    def act_inference(self, observations):
+    def act_inference(self, observations, velocity=None):
+        if velocity is not None:
+            observations[:, 9] = velocity
         input_a = self.memory_a(observations)
         return super().act_inference(input_a.squeeze(0))
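
With this change the module no longer owns the planner, so at deployment act_inference expects the caller to compute and pass the velocity. A hedged sketch of how an evaluation script might drive it; act_with_planner is a hypothetical helper, and the observation layout is the same assumption as above:

import torch

@torch.no_grad()
def act_with_planner(policy, velocity_planner, obs, lin_vel_x):
    # policy: a trained ActorCriticRecurrent; velocity_planner: the MLP built by the runner (next file).
    vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
    velocity = torch.clip(velocity_planner(vel_obs), lin_vel_x[0], lin_vel_x[1])
    actions = policy.act_inference(obs, velocity=velocity.squeeze(-1))
    return actions, velocity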

View File

@@ -35,6 +35,7 @@ import statistics
 from torch.utils.tensorboard import SummaryWriter
 import torch
+import torch.nn as nn

 import rsl_rl.algorithms as algorithms
 import rsl_rl.modules as modules
@@ -61,8 +62,17 @@ class OnPolicyRunner:
             self.policy_cfg,
         ).to(self.device)

+        velocity_planner = nn.Sequential(
+            nn.Linear(env.num_obs-3, 256),
+            nn.ELU(),
+            nn.Linear(256, 128),
+            nn.ELU(),
+            nn.Linear(128, 1)
+        ).to(self.device)
+
         alg_class = getattr(algorithms, self.cfg["algorithm_class_name"]) # PPO
-        self.alg: algorithms.PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
+        self.alg: algorithms.PPO = alg_class(actor_critic, velocity_planner, device=self.device, **self.alg_cfg)

         self.num_steps_per_env = self.cfg["num_steps_per_env"]
         self.save_interval = self.cfg["save_interval"]
@@ -141,8 +151,8 @@
             self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))

     def rollout_step(self, obs, critic_obs):
-        actions = self.alg.act(obs, critic_obs)
-        obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
+        actions, velocity = self.alg.act(obs, critic_obs)
+        obs, privileged_obs, rewards, dones, infos = self.env.step(actions, velocity)
         critic_obs = privileged_obs if privileged_obs is not None else obs
         obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
         self.alg.process_env_step(rewards, dones, infos)
@@ -229,6 +239,7 @@
         run_state_dict = {
             'model_state_dict': self.alg.actor_critic.state_dict(),
             'optimizer_state_dict': self.alg.optimizer.state_dict(),
+            'velocity_optimizer_state_dict': self.alg.velocity_optimizer.state_dict(),
             'iter': self.current_learning_iteration,
             'infos': infos,
         }
@@ -240,7 +251,9 @@
         loaded_dict = torch.load(path)
         self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
         if load_optimizer and "optimizer_state_dict" in loaded_dict:
             self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
+        if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict:
+            self.alg.velocity_optimizer.load_state_dict(loaded_dict['velocity_optimizer_state_dict'])
         if "lr_scheduler_state_dict" in loaded_dict:
             if not hasattr(self.alg, "lr_scheduler"):
                 print("Warning: lr_scheduler_state_dict found in checkpoint but no lr_scheduler in algorithm. Ignoring.")