velocity loss added

parent da35cb1cb6
commit 34e6399781

@@ -11,7 +11,7 @@ class A1LeapCfg( A1FieldCfg ):
     # latency_range = [0.04-0.0025, 0.04+0.0075]
     #### uncomment the above to train non-virtual terrain
     class env(A1FieldCfg.env):
-        num_envs = 4
+        num_envs = 4096
     class terrain( A1FieldCfg.terrain ):
         max_init_terrain_level = 2
         border_size = 5
@@ -37,7 +37,7 @@ class A1LeapCfg( A1FieldCfg ):
 
     class commands( A1FieldCfg.commands ):
         class ranges( A1FieldCfg.commands.ranges ):
-            lin_vel_x = [1.0, 4.0]
+            lin_vel_x = [1.5, 3.0]
             lin_vel_y = [0.0, 0.0]
             ang_vel_yaw = [0., 0.]
 
@@ -63,11 +63,12 @@ class A1LeapCfg( A1FieldCfg ):
             tracking_ang_vel = 0.05
             world_vel_l2norm = -1.
             legs_energy_substeps = -1e-6
-            alive = 2.
+            alive = 1. # 2.
             penetrate_depth = -4e-3
             penetrate_volume = -4e-3
             exceed_dof_pos_limits = -1e-1
             exceed_torque_limits_i = -2e-1
+            lin_pos_x = 1.
             # track_predict_vel_l2norm = -1.
         soft_dof_pos_limit = 0.9
 
@@ -82,7 +83,7 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
     class algorithm( A1FieldCfgPPO.algorithm ):
         entropy_coef = 0.0
         clip_min_std = 0.2
-        lin_vel_x = [2.0, 3.0]
+        lin_vel_x = [0.5, 3.0]
         command_scale = 2.0
 
     class runner( A1FieldCfgPPO.runner ):
@@ -101,7 +102,7 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
         resume = True
         # load_run = "{Your traind walking model directory}"
         # load_run = "May16_18-12-08_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
-        load_run = "High_speed_walk"
+        load_run = "Leap_2m_2500"
         # load_run = "May15_21-34-27_Skillleap_pEnergySubsteps-1e-06_virtual"#"May15_17-07-38_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
         # load_run = "{Your virtually trained leap model directory}"
         max_iterations = 20000
@@ -1037,7 +1037,9 @@ class LeggedRobotField(LeggedRobot):
         world_vel_error = torch.sum(torch.square(self.commands[:, :2] - self.root_states[:, 7:9]), dim= 1)
         return (1 - torch.exp(-world_vel_error/self.cfg.rewards.tracking_sigma)) * engaging_mask # reverse version of tracking reward
 
+    def _reward_lin_pos_x(self):
+        return torch.abs((self.root_states[:, :3] - self.env_origins)[:, 0])
 
     ##### Some helper functions that override parent class attributes #####
     @property
     def all_obs_components(self):
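
For intuition, here is a minimal standalone check of what the new reward term measures: the absolute x-displacement of the robot base from its per-env origin, which the new lin_pos_x = 1. scale above turns into a per-step reward for forward progress. The shapes follow the usual legged_gym conventions and the tensors below are made-up examples, not values from the commit.

    import torch

    # root_states: (num_envs, 13) base state, env_origins: (num_envs, 3) spawn origins
    root_states = torch.zeros(1, 13)
    root_states[0, :3] = torch.tensor([1.7, 0.0, 0.3])   # base 1.7 m past its origin in x
    env_origins = torch.zeros(1, 3)

    reward = torch.abs((root_states[:, :3] - env_origins)[:, 0])
    print(reward)   # tensor([1.7000])
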
@@ -106,14 +106,14 @@ def play(args):
         # "tilt",
     ]
     env_cfg.terrain.BarrierTrack_kwargs["leap"] = dict(
-        length= (1.3, 1.3),
+        length= (1.5, 1.5),
         depth= (0.4, 0.8),
         height= 0.2,
     )
 
     if "one_obstacle_per_track" in env_cfg.terrain.BarrierTrack_kwargs.keys():
         env_cfg.terrain.BarrierTrack_kwargs.pop("one_obstacle_per_track")
-    env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 2# 2
+    env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 1# 2
     env_cfg.commands.ranges.lin_vel_x = [3.0, 3.0] # [1.2, 1.2]
     env_cfg.terrain.BarrierTrack_kwargs['track_block_length']= 3.
     if "distill" in args.task:
@@ -239,11 +239,11 @@ def play(args):
         if "obs_slice" in locals().keys():
             obs_component = obs[:, obs_slice[0]].reshape(-1, *obs_slice[1])
             print(obs_component[robot_index])
-        vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
+        vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
         velocity = velocity_planner(vel_obs)
-        print(velocity)
-        print(env_cfg.commands.ranges.lin_vel_x)
-        velocity = torch.clip(velocity, env_cfg.commands.ranges.lin_vel_x[0], env_cfg.commands.ranges.lin_vel_x[1])
+        env.commands[..., 0] = velocity.squeeze(-1)
+        obs[..., 9] = velocity.squeeze(-1) * env.obs_scales.lin_vel
+        # velocity = torch.clip(velocity, env_cfg.commands.ranges.lin_vel_x[0], env_cfg.commands.ranges.lin_vel_x[1])
         actions = policy(obs.detach())
         teacher_actions = actions
         obs, critic_obs, rews, dones, infos = env.step(actions.detach(), velocity)
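
The velocity_planner module itself never appears in this diff; the call sites above only tell us that it maps the observation with the command slots removed to a single forward-velocity value per env. A minimal sketch of such a network, purely as an assumption about its interface (the real architecture, hidden sizes, and how it is constructed and loaded live elsewhere in the repo):

    import torch
    import torch.nn as nn

    class VelocityPlanner(nn.Module):
        """Hypothetical stand-in: maps obs with the command block (indices 9:12)
        removed to one forward-velocity command per environment."""
        def __init__(self, num_vel_obs, hidden_dims=(128, 64)):
            super().__init__()
            layers, in_dim = [], num_vel_obs
            for h in hidden_dims:
                layers += [nn.Linear(in_dim, h), nn.ELU()]
                in_dim = h
            layers.append(nn.Linear(in_dim, 1))
            self.net = nn.Sequential(*layers)

        def forward(self, vel_obs):
            # vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
            return self.net(vel_obs)   # shape (..., 1); callers squeeze(-1)
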
@@ -32,6 +32,7 @@ from collections import defaultdict
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import numpy as np
 import copy
 
 from rsl_rl.modules import ActorCritic
@@ -99,9 +100,11 @@ class PPO:
 
     def test_mode(self):
         self.actor_critic.test()
+        self.velocity_planner.eval()
 
     def train_mode(self):
         self.actor_critic.train()
+        self.velocity_planner.train()
 
     def act(self, obs, critic_obs):
         if self.actor_critic.is_recurrent:
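
The PPO constructor changes are not part of this commit, but test_mode/train_mode above and save/load further down all assume velocity_planner and velocity_optimizer attributes. A plausible wiring in PPO.__init__, written as an assumption inferred from those call sites (attribute names come from the diff, everything else is hypothetical):

    # Inside PPO.__init__, after the actor_critic / optimizer setup (hypothetical):
    self.velocity_planner = velocity_planner.to(self.device)   # e.g. the MLP sketched above
    self.velocity_optimizer = optim.Adam(self.velocity_planner.parameters(), lr=learning_rate)
    self.command_scale = command_scale   # 2.0 in A1LeapCfgPPO.algorithm
    self.lin_vel_x = lin_vel_x           # clip range; unused now that clipping is commented out
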
@@ -109,10 +112,10 @@ class PPO:
         # Compute the actions and values
         vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
         velocity = self.velocity_planner(vel_obs)
-        if self.lin_vel_x is not None:
-            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        velocity *= self.command_scale
-        self.transition.actions = self.actor_critic.act(obs, velocity=velocity)[0].detach()
+        # if self.lin_vel_x is not None:
+        # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.transition.actions = self.actor_critic.act(obs, velocity=velocity * self.command_scale)[0].detach()
+        critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
         self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
         self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.actor_critic.action_mean.detach()
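
The index-9 writes only make sense under a particular observation layout, which the commit relies on but never states. The slicing here and in play() suggests the following layout; this is an inference, not documented in the diff:

    # Inferred layout of the proprioceptive observation (assumption):
    # obs[..., 0:9]   terms preceding the command block
    # obs[..., 9:12]  scaled commands: lin_vel_x, lin_vel_y, ang_vel_yaw
    # obs[..., 12:]   remaining terms (dof pos/vel, last actions, ...)
    vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)      # planner never sees its own command
    critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale  # critic sees the planned command
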
@@ -148,7 +151,7 @@ class PPO:
         generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
         for minibatch in generator:
 
-            losses, _, stats = self.compute_losses(minibatch)
+            losses, _, stats = self.compute_losses(minibatch, current_learning_iteration=current_learning_iteration)
 
             loss = 0.
             for k, v in losses.items():
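
For this to work, update() itself has to receive the iteration counter from the runner; that change is not shown in the diff, so the call below is only a guess at the shape of the missing piece:

    # Presumed call site in OnPolicyRunner.learn() (not shown in this commit):
    mean_losses, average_stats = self.alg.update(current_learning_iteration=self.current_learning_iteration)
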
@@ -177,15 +180,15 @@ class PPO:
 
         return mean_losses, average_stats
 
-    def compute_losses(self, minibatch):
+    def compute_losses(self, minibatch, current_learning_iteration=None):
         obs = copy.deepcopy(minibatch.obs)
-        # print(obs.shape)
+
         vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
-        # print(vel_obs.shape)
+
         velocity = self.velocity_planner(vel_obs)
-        if self.lin_vel_x is not None:
-            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity)
+        # if self.lin_vel_x is not None:
+        # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
         actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
         value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
         mu_batch = self.actor_critic.action_mean
@@ -228,9 +231,17 @@ class PPO:
         else:
             value_loss = (minibatch.returns - value_batch).pow(2).mean()
 
+        # Velocity loss
+        if current_learning_iteration is None:
+            vel_loss = 0
+        else:
+            vel_loss = torch.square(velocity-2).mean() * np.exp(-0.01 * current_learning_iteration + 125)
+            vel_loss += torch.square(torch.clamp_max(velocity, 1.) - 1).mean()
+
         return_ = dict(
             surrogate_loss= surrogate_loss,
             value_loss= value_loss,
+            vel_loss = vel_loss
         )
         if entropy_batch is not None:
             return_["entropy"] = - entropy_batch.mean()
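
The weight on the (velocity - 2)^2 term decays exponentially with the learning iteration and equals 1.0 exactly at iteration 12500; at iteration 0 it would be e^125, which is only sensible because this config resumes from a checkpoint whose iteration counter is already large (resume = True, max_iterations = 20000 above). The clamp_max term keeps discouraging planned velocities below 1 at every iteration. A quick standalone check of the schedule:

    import numpy as np

    # Weight applied to the (velocity - 2)^2 term at a few iteration counts
    for it in (11000, 12000, 12500, 13000, 15000):
        coef = np.exp(-0.01 * it + 125)
        print(it, coef)
    # 11000 ~3.3e6   12000 ~148.4   12500 1.0   13000 ~6.7e-3   15000 ~1.4e-11
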
@@ -221,6 +221,7 @@ class OnPolicyRunner:
                           'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                           f"""{'Value function loss:':>{pad}} {locs["losses"]['value_loss']:.4f}\n"""
                           f"""{'Surrogate loss:':>{pad}} {locs["losses"]['surrogate_loss']:.4f}\n"""
+                          f"""{'Velocity loss:':>{pad}} {locs["losses"]['vel_loss']:.4f}\n"""
                           f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
                           # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
                           # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
@@ -238,6 +239,7 @@ class OnPolicyRunner:
     def save(self, path, infos=None):
         run_state_dict = {
             'model_state_dict': self.alg.actor_critic.state_dict(),
+            'velocity_planner_state_dict': self.alg.velocity_planner.state_dict(),
             'optimizer_state_dict': self.alg.optimizer.state_dict(),
             'velocity_optimizer_state_dict': self.alg.velocity_optimizer.state_dict(),
             'iter': self.current_learning_iteration,
@@ -250,6 +252,8 @@ class OnPolicyRunner:
     def load(self, path, load_optimizer=True):
         loaded_dict = torch.load(path)
         self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
+        if 'velocity_planner_state_dict' in loaded_dict:
+            self.alg.velocity_planner.load_state_dict(loaded_dict['velocity_planner_state_dict'])
         if load_optimizer and "optimizer_state_dict" in loaded_dict:
             self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
         if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict:
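
The membership check keeps older checkpoints loadable: a run saved before this commit has no 'velocity_planner_state_dict' entry, so the planner simply keeps its fresh initialization while the policy weights are restored. For example (path illustrative only):

    # Resuming from a pre-velocity-planner checkpoint:
    runner.load("logs/field_a1/High_speed_walk/model_20000.pt")
    # actor_critic and optimizer states are restored; the velocity planner keeps its
    # random init and is trained from this point on.
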