velocity planner added
parent 4f0f062efa
commit 44e8cbf692

@@ -32,6 +32,7 @@ from collections import defaultdict
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import copy

 from rsl_rl.modules import ActorCritic
 from rsl_rl.storage import RolloutStorage
@@ -40,6 +41,7 @@ class PPO:
     actor_critic: ActorCritic
     def __init__(self,
                  actor_critic,
+                 velocity_planner,
                  num_learning_epochs=1,
                  num_mini_batches=1,
                  clip_param=0.2,
@@ -55,6 +57,7 @@ class PPO:
                  schedule="fixed",
                  desired_kl=0.01,
                  device='cpu',
+                 **kwargs
                  ):

         self.device = device
@@ -70,6 +73,11 @@ class PPO:
         self.optimizer = getattr(optim, optimizer_class_name)(self.actor_critic.parameters(), lr=learning_rate)
         self.transition = RolloutStorage.Transition()

+        # Velocity Planner
+        self.velocity_planner = velocity_planner
+        self.velocity_optimizer = getattr(optim, optimizer_class_name)(self.velocity_planner.parameters(), lr=learning_rate)
+        self.lin_vel_x = kwargs.get('lin_vel_x', None)
+
         # PPO parameters
         self.clip_param = clip_param
         self.num_learning_epochs = num_learning_epochs
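
Note: the velocity planner gets its own optimizer of the same class and learning rate as the policy optimizer, and its clipping range arrives through the extra keyword arguments. A minimal sketch of how `lin_vel_x` might be supplied (the key name comes from the diff; the surrounding dict and concrete values are assumptions, not part of the commit):

# Hypothetical algorithm config forwarded as PPO(actor_critic, velocity_planner, **alg_cfg)
alg_cfg = {
    "clip_param": 0.2,
    "desired_kl": 0.01,
    "lin_vel_x": [0.0, 1.0],   # assumed (min, max) range for the planned forward velocity
}

lin_vel_x = alg_cfg.get("lin_vel_x", None)   # mirrors kwargs.get('lin_vel_x', None) above
print(lin_vel_x)                             # [0.0, 1.0]
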
@@ -98,7 +106,12 @@ class PPO:
         if self.actor_critic.is_recurrent:
             self.transition.hidden_states = self.actor_critic.get_hidden_states()
         # Compute the actions and values
-        self.transition.actions = self.actor_critic.act(obs).detach()
+        vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
+        velocity = self.velocity_planner(vel_obs)
+        if self.lin_vel_x is not None:
+            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+
+        self.transition.actions = self.actor_critic.act(obs, velocity=velocity)[0].detach()
         self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
         self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.actor_critic.action_mean.detach()
@@ -106,7 +119,7 @@ class PPO:
         # need to record obs and critic_obs before env.step()
         self.transition.observations = obs
         self.transition.critic_observations = critic_obs
-        return self.transition.actions
+        return self.transition.actions, velocity.squeeze().detach()

     def process_env_step(self, rewards, dones, infos):
         self.transition.rewards = rewards.clone()
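
Taken together, the rollout path above works as follows: the three command entries (which the slicing implies sit at observation indices 9:12) are removed before the planner sees the observation, the planner predicts a single forward velocity, PPO clips it to `lin_vel_x`, and the recurrent actor writes it back into the command slot. A minimal, self-contained sketch of that data flow (batch size, observation width, and the clip range are assumptions; the MLP matches the one built by the runner below):

import torch
import torch.nn as nn

num_envs, num_obs = 4, 48                      # assumed sizes for illustration
obs = torch.randn(num_envs, num_obs)

velocity_planner = nn.Sequential(              # same architecture as in OnPolicyRunner below
    nn.Linear(num_obs - 3, 256), nn.ELU(),
    nn.Linear(256, 128), nn.ELU(),
    nn.Linear(128, 1),
)

# Drop the three command entries (indices 9:12), as in PPO.act() above.
vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)        # (num_envs, num_obs - 3)
velocity = torch.clip(velocity_planner(vel_obs), 0.0, 1.0)   # assumed lin_vel_x = (0.0, 1.0)

# ActorCriticRecurrent.act() then overrides the forward-velocity command in place:
obs[:, 9] = velocity.squeeze()
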
@@ -146,9 +159,11 @@ class PPO:

         # Gradient step
         self.optimizer.zero_grad()
+        self.velocity_optimizer.zero_grad()
         loss.backward()
         nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
         self.optimizer.step()
+        self.velocity_optimizer.step()

         num_updates = self.num_learning_epochs * self.num_mini_batches
         for k in mean_losses.keys():
@@ -162,9 +177,16 @@ class PPO:
         return mean_losses, average_stats

     def compute_losses(self, minibatch):
-        self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0])
+        obs = copy.deepcopy(minibatch.obs)
+        # print(obs.shape)
+        vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
+        # print(vel_obs.shape)
+        velocity = self.velocity_planner(vel_obs)
+        if self.lin_vel_x is not None:
+            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity)
         actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
-        value_batch = self.actor_critic.evaluate(minibatch.critic_obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
+        value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
         mu_batch = self.actor_critic.action_mean
         sigma_batch = self.actor_critic.action_std
         try:
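
During the update, `compute_losses` recomputes the velocity from a copy of the mini-batch observations and feeds it through the actor, so `loss.backward()` also populates the planner's gradients and the separate `velocity_optimizer` applies them alongside the policy step. A toy sketch of that two-optimizer pattern (the linear modules are stand-ins, not the real ActorCritic or planner):

import torch
import torch.nn as nn
import torch.optim as optim

planner = nn.Linear(45, 1)                    # stand-in velocity planner (48 - 3 inputs)
policy = nn.Linear(48, 12)                    # stand-in for the actor network
planner_opt = optim.Adam(planner.parameters(), lr=1e-3)
policy_opt = optim.Adam(policy.parameters(), lr=1e-3)

minibatch_obs = torch.randn(4, 48)
obs = minibatch_obs.clone()                   # stand-in for copy.deepcopy(minibatch.obs)
vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
obs[..., 9] = planner(vel_obs).squeeze(-1)    # planner output enters the policy input

loss = policy(obs).pow(2).mean()              # stand-in for the PPO surrogate + value loss
policy_opt.zero_grad()
planner_opt.zero_grad()
loss.backward()                               # gradients reach the planner through obs[..., 9]
policy_opt.step()
planner_opt.step()
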
@@ -62,15 +62,6 @@ class ActorCriticRecurrent(ActorCritic):

         activation = get_activation(activation)

-        self.velocity_planner = nn.Sequential(
-            nn.Linear(num_actor_obs-3, 256),
-            nn.ELU(),
-            nn.Linear(256, 128),
-            nn.ELU(),
-            nn.Linear(128, 1),
-            nn.Tanh()
-        )
-
         self.lin_vel_x = kwargs["lin_vel_x"]
         self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
         self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
@@ -82,16 +73,19 @@ class ActorCriticRecurrent(ActorCritic):
         self.memory_a.reset(dones)
         self.memory_c.reset(dones)

-    def act(self, observations, masks=None, hidden_states=None):
-        vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
-        velocity = self.velocity_planner(vel_obs)
-        velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        self.velocity = velocity
-        observations[:, 9] = velocity
+    def act(self, observations, masks=None, hidden_states=None, velocity=None):
+        if velocity is not None:
+            observations[..., 9] = velocity.squeeze()
+        # vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
+        # velocity = self.velocity_planner(vel_obs)
+        # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        # observations[:, 9] = velocity
         input_a = self.memory_a(observations, masks, hidden_states)
-        return super().act(input_a.squeeze(0))
+        return super().act(input_a.squeeze(0)), velocity.squeeze().detach()

-    def act_inference(self, observations):
+    def act_inference(self, observations, velocity=None):
+        if velocity is not None:
+            observations[:, 9] = velocity
         input_a = self.memory_a(observations)
         return super().act_inference(input_a.squeeze(0))

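
The recurrent actor now takes the velocity from the caller instead of computing it itself, and both `act()` and `act_inference()` write it in place into observation slot 9, which also mutates the caller's tensor; note that `act()` dereferences `velocity` in its return value, so callers are expected to always pass one. A small sketch of the injection (slot 9 being the commanded forward velocity is an assumption the slicing implies):

import torch

observations = torch.zeros(2, 48)             # assumed observation width
velocity = torch.tensor([[0.6], [0.8]])       # planner output, shape (num_envs, 1)

observations[..., 9] = velocity.squeeze()     # as in act(); act_inference() does the same
print(observations[:, 9])                     # tensor([0.6000, 0.8000])
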
@@ -35,6 +35,7 @@ import statistics

 from torch.utils.tensorboard import SummaryWriter
 import torch
+import torch.nn as nn

 import rsl_rl.algorithms as algorithms
 import rsl_rl.modules as modules
@@ -61,8 +62,17 @@ class OnPolicyRunner:
             self.policy_cfg,
         ).to(self.device)

+
+        velocity_planner = nn.Sequential(
+            nn.Linear(env.num_obs-3, 256),
+            nn.ELU(),
+            nn.Linear(256, 128),
+            nn.ELU(),
+            nn.Linear(128, 1)
+        ).to(self.device)
+
         alg_class = getattr(algorithms, self.cfg["algorithm_class_name"]) # PPO
-        self.alg: algorithms.PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
+        self.alg: algorithms.PPO = alg_class(actor_critic, velocity_planner, device=self.device, **self.alg_cfg)

         self.num_steps_per_env = self.cfg["num_steps_per_env"]
         self.save_interval = self.cfg["save_interval"]
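
The planner now lives in the runner rather than inside `ActorCriticRecurrent`, and unlike the removed module it has no `Tanh` on the output, so the output range is bounded only by the `torch.clip` in PPO. A quick shape check of the module built above (`env.num_obs = 48` is an assumption):

import torch
import torch.nn as nn

num_obs = 48                                   # assumed env.num_obs
velocity_planner = nn.Sequential(
    nn.Linear(num_obs - 3, 256), nn.ELU(),
    nn.Linear(256, 128), nn.ELU(),
    nn.Linear(128, 1),
)

vel_obs = torch.randn(16, num_obs - 3)
print(velocity_planner(vel_obs).shape)         # torch.Size([16, 1])
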
@@ -141,8 +151,8 @@ class OnPolicyRunner:
         self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))

     def rollout_step(self, obs, critic_obs):
-        actions = self.alg.act(obs, critic_obs)
-        obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
+        actions, velocity = self.alg.act(obs, critic_obs)
+        obs, privileged_obs, rewards, dones, infos = self.env.step(actions, velocity)
         critic_obs = privileged_obs if privileged_obs is not None else obs
         obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
         self.alg.process_env_step(rewards, dones, infos)
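
`rollout_step` now forwards the planned velocity to `self.env.step(actions, velocity)`; the environment-side change is not part of this diff. A purely hypothetical stub of a compatible interface (class name, attribute names, and the command layout are all assumptions):

import torch

class VelocityCommandEnvStub:
    """Hypothetical, not from this commit: shows the step() signature rollout_step() expects."""

    def __init__(self, num_envs=4, num_commands=4):
        self.commands = torch.zeros(num_envs, num_commands)   # assumed legged_gym-style command buffer

    def step(self, actions, velocity=None):
        if velocity is not None:
            self.commands[:, 0] = velocity                    # override the forward-velocity command
        # a real environment would simulate here and build the returned tensors
        obs = privileged_obs = rewards = dones = infos = None  # placeholders only
        return obs, privileged_obs, rewards, dones, infos
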
@@ -229,6 +239,7 @@ class OnPolicyRunner:
         run_state_dict = {
             'model_state_dict': self.alg.actor_critic.state_dict(),
             'optimizer_state_dict': self.alg.optimizer.state_dict(),
+            'velocity_optimizer_state_dict': self.alg.velocity_optimizer.state_dict(),
             'iter': self.current_learning_iteration,
             'infos': infos,
         }
@@ -240,7 +251,9 @@ class OnPolicyRunner:
         loaded_dict = torch.load(path)
         self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
         if load_optimizer and "optimizer_state_dict" in loaded_dict:
-            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
+            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
+        if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict:
+            self.alg.velocity_optimizer.load_state_dict(loaded_dict['velocity_optimizer_state_dict'], )
         if "lr_scheduler_state_dict" in loaded_dict:
             if not hasattr(self.alg, "lr_scheduler"):
                 print("Warning: lr_scheduler_state_dict found in checkpoint but no lr_scheduler in algorithm. Ignoring.")
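
Finally, the checkpoint gains the velocity optimizer state; note that as the save hunk stands, `model_state_dict` only covers the actor-critic, so the planner's own weights are not part of `run_state_dict`. A round-trip sketch of exactly what the diff saves and loads (toy modules stand in for the real ones):

import torch
import torch.nn as nn
import torch.optim as optim

actor_critic = nn.Linear(48, 12)               # toy stand-ins for illustration only
velocity_planner = nn.Linear(45, 1)
optimizer = optim.Adam(actor_critic.parameters())
velocity_optimizer = optim.Adam(velocity_planner.parameters())

torch.save({
    'model_state_dict': actor_critic.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'velocity_optimizer_state_dict': velocity_optimizer.state_dict(),
    'iter': 0,
    'infos': None,
}, 'model_0.pt')

loaded_dict = torch.load('model_0.pt')
actor_critic.load_state_dict(loaded_dict['model_state_dict'])
if 'optimizer_state_dict' in loaded_dict:
    optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
if 'velocity_optimizer_state_dict' in loaded_dict:
    velocity_optimizer.load_state_dict(loaded_dict['velocity_optimizer_state_dict'])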