velocity planner added
parent 4f0f062efa
commit 44e8cbf692
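The change moves the velocity planner out of `ActorCriticRecurrent` and into the training stack: `OnPolicyRunner` builds a small MLP head and hands it to `PPO`, which owns it together with its own optimizer, clips its scalar output to the `lin_vel_x` range, and writes it into slot 9 of the observation (what appears to be the commanded forward velocity) before the policy acts. A minimal sketch of that data flow, assembled from the hunks below; the batch size, observation width, and clip range here are illustrative placeholders, not values from this repository:

    import torch
    import torch.nn as nn

    num_envs, num_obs = 4, 48            # illustrative shapes only
    lin_vel_x = (-1.0, 1.0)              # assumed clip range; supplied via kwargs['lin_vel_x'] in the diff

    # same architecture as the planner constructed in the OnPolicyRunner hunk below
    velocity_planner = nn.Sequential(
        nn.Linear(num_obs - 3, 256), nn.ELU(),
        nn.Linear(256, 128), nn.ELU(),
        nn.Linear(128, 1),
    )

    obs = torch.randn(num_envs, num_obs)
    # planner input: the observation with the three command entries (indices 9:12) removed
    vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
    velocity = torch.clip(velocity_planner(vel_obs), lin_vel_x[0], lin_vel_x[1])
    obs[:, 9] = velocity.squeeze(-1)     # planned forward velocity replaces the commanded one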
@@ -32,6 +32,7 @@ from collections import defaultdict
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import copy
 
 from rsl_rl.modules import ActorCritic
 from rsl_rl.storage import RolloutStorage
@@ -40,6 +41,7 @@ class PPO:
     actor_critic: ActorCritic
     def __init__(self,
                  actor_critic,
+                 velocity_planner,
                  num_learning_epochs=1,
                  num_mini_batches=1,
                  clip_param=0.2,
@@ -55,6 +57,7 @@ class PPO:
                  schedule="fixed",
                  desired_kl=0.01,
                  device='cpu',
+                 **kwargs
                  ):
 
         self.device = device
@@ -70,6 +73,11 @@ class PPO:
         self.optimizer = getattr(optim, optimizer_class_name)(self.actor_critic.parameters(), lr=learning_rate)
         self.transition = RolloutStorage.Transition()
 
+        # Velocity Planner
+        self.velocity_planner = velocity_planner
+        self.velocity_optimizer = getattr(optim, optimizer_class_name)(self.velocity_planner.parameters(), lr=learning_rate)
+        self.lin_vel_x = kwargs.get('lin_vel_x', None)
+
         # PPO parameters
         self.clip_param = clip_param
         self.num_learning_epochs = num_learning_epochs
@@ -98,7 +106,12 @@ class PPO:
         if self.actor_critic.is_recurrent:
             self.transition.hidden_states = self.actor_critic.get_hidden_states()
         # Compute the actions and values
-        self.transition.actions = self.actor_critic.act(obs).detach()
+        vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
+        velocity = self.velocity_planner(vel_obs)
+        if self.lin_vel_x is not None:
+            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+
+        self.transition.actions = self.actor_critic.act(obs, velocity=velocity)[0].detach()
         self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
         self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.actor_critic.action_mean.detach()
@@ -106,7 +119,7 @@ class PPO:
         # need to record obs and critic_obs before env.step()
         self.transition.observations = obs
         self.transition.critic_observations = critic_obs
-        return self.transition.actions
+        return self.transition.actions, velocity.squeeze().detach()
 
     def process_env_step(self, rewards, dones, infos):
         self.transition.rewards = rewards.clone()
@@ -146,9 +159,11 @@ class PPO:
 
             # Gradient step
             self.optimizer.zero_grad()
+            self.velocity_optimizer.zero_grad()
             loss.backward()
             nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
             self.optimizer.step()
+            self.velocity_optimizer.step()
 
         num_updates = self.num_learning_epochs * self.num_mini_batches
         for k in mean_losses.keys():
@@ -162,9 +177,16 @@ class PPO:
         return mean_losses, average_stats
 
     def compute_losses(self, minibatch):
-        self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0])
+        obs = copy.deepcopy(minibatch.obs)
+        # print(obs.shape)
+        vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
+        # print(vel_obs.shape)
+        velocity = self.velocity_planner(vel_obs)
+        if self.lin_vel_x is not None:
+            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity)
         actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
-        value_batch = self.actor_critic.evaluate(minibatch.critic_obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
+        value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
         mu_batch = self.actor_critic.action_mean
         sigma_batch = self.actor_critic.action_std
         try:
@@ -62,15 +62,6 @@ class ActorCriticRecurrent(ActorCritic):
 
         activation = get_activation(activation)
 
-        self.velocity_planner = nn.Sequential(
-            nn.Linear(num_actor_obs-3, 256),
-            nn.ELU(),
-            nn.Linear(256, 128),
-            nn.ELU(),
-            nn.Linear(128, 1),
-            nn.Tanh()
-        )
 
-        self.lin_vel_x = kwargs["lin_vel_x"]
         self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
         self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
@@ -82,16 +73,19 @@ class ActorCriticRecurrent(ActorCritic):
         self.memory_a.reset(dones)
         self.memory_c.reset(dones)
 
-    def act(self, observations, masks=None, hidden_states=None):
-        vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
-        velocity = self.velocity_planner(vel_obs)
-        velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        self.velocity = velocity
-        observations[:, 9] = velocity
+    def act(self, observations, masks=None, hidden_states=None, velocity=None):
+        if velocity is not None:
+            observations[..., 9] = velocity.squeeze()
+        # vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
+        # velocity = self.velocity_planner(vel_obs)
+        # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        # observations[:, 9] = velocity
         input_a = self.memory_a(observations, masks, hidden_states)
-        return super().act(input_a.squeeze(0))
+        return super().act(input_a.squeeze(0)), velocity.squeeze().detach()
 
-    def act_inference(self, observations):
+    def act_inference(self, observations, velocity=None):
+        if velocity is not None:
+            observations[:, 9] = velocity
         input_a = self.memory_a(observations)
         return super().act_inference(input_a.squeeze(0))
 
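With this change the recurrent actor no longer runs a planner internally: the caller computes the velocity, passes it through the new `velocity=` argument (which overwrites `observations[..., 9]`), and `act` returns the actions together with the detached velocity. A sketch of the new calling convention, assuming `actor_critic`, `velocity_planner`, and `lin_vel_x` are set up as in the PPO and runner hunks:

    vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
    velocity = torch.clip(velocity_planner(vel_obs), lin_vel_x[0], lin_vel_x[1])
    actions, planned_vel = actor_critic.act(obs, velocity=velocity)

Note that the new return statement still dereferences `velocity` when it is `None`, so callers are expected to always supply one; `act_inference` keeps its single-tensor return.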
@@ -35,6 +35,7 @@ import statistics
 
 from torch.utils.tensorboard import SummaryWriter
 import torch
+import torch.nn as nn
 
 import rsl_rl.algorithms as algorithms
 import rsl_rl.modules as modules
@@ -61,8 +62,17 @@ class OnPolicyRunner:
             self.policy_cfg,
         ).to(self.device)
 
+        velocity_planner = nn.Sequential(
+            nn.Linear(env.num_obs-3, 256),
+            nn.ELU(),
+            nn.Linear(256, 128),
+            nn.ELU(),
+            nn.Linear(128, 1)
+        ).to(self.device)
+
         alg_class = getattr(algorithms, self.cfg["algorithm_class_name"]) # PPO
-        self.alg: algorithms.PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
+        self.alg: algorithms.PPO = alg_class(actor_critic, velocity_planner, device=self.device, **self.alg_cfg)
+
         self.num_steps_per_env = self.cfg["num_steps_per_env"]
         self.save_interval = self.cfg["save_interval"]
@@ -141,8 +151,8 @@ class OnPolicyRunner:
             self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))
 
     def rollout_step(self, obs, critic_obs):
-        actions = self.alg.act(obs, critic_obs)
-        obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
+        actions, velocity = self.alg.act(obs, critic_obs)
+        obs, privileged_obs, rewards, dones, infos = self.env.step(actions, velocity)
         critic_obs = privileged_obs if privileged_obs is not None else obs
         obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
         self.alg.process_env_step(rewards, dones, infos)
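`rollout_step` now forwards the planned velocity into `self.env.step(actions, velocity)`, which presumes the environment in this fork accepts a second argument; the upstream rsl_rl `VecEnv.step` interface passes only actions. A hypothetical env-side hook illustrating what such an extension could look like (the `commands` buffer and `_physics_step` name are assumptions, not code from this repository):

    class VelocityCommandEnv:
        # hypothetical sketch of an env whose step() accepts the planned velocity
        def step(self, actions, velocity=None):
            if velocity is not None:
                # e.g. feed the planned forward velocity back in as the x-velocity command
                self.commands[:, 0] = velocity
            return self._physics_step(actions)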
@@ -229,6 +239,7 @@ class OnPolicyRunner:
         run_state_dict = {
             'model_state_dict': self.alg.actor_critic.state_dict(),
             'optimizer_state_dict': self.alg.optimizer.state_dict(),
+            'velocity_optimizer_state_dict': self.alg.velocity_optimizer.state_dict(),
             'iter': self.current_learning_iteration,
             'infos': infos,
         }
@@ -240,7 +251,9 @@ class OnPolicyRunner:
         loaded_dict = torch.load(path)
         self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
         if load_optimizer and "optimizer_state_dict" in loaded_dict:
-            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
+            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
+        if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict:
+            self.alg.velocity_optimizer.load_state_dict(loaded_dict['velocity_optimizer_state_dict'], )
         if "lr_scheduler_state_dict" in loaded_dict:
             if not hasattr(self.alg, "lr_scheduler"):
                 print("Warning: lr_scheduler_state_dict found in checkpoint but no lr_scheduler in algorithm. Ignoring.")