velocity loss added

Jerry Xu 2024-05-26 16:12:42 -04:00
parent da35cb1cb6
commit 34e6399781
5 changed files with 41 additions and 23 deletions

View File

@@ -11,7 +11,7 @@ class A1LeapCfg( A1FieldCfg ):
     # latency_range = [0.04-0.0025, 0.04+0.0075]
     #### uncomment the above to train non-virtual terrain
     class env( A1FieldCfg.env ):
-        num_envs = 4
+        num_envs = 4096
     class terrain( A1FieldCfg.terrain ):
         max_init_terrain_level = 2
         border_size = 5
@@ -37,7 +37,7 @@ class A1LeapCfg( A1FieldCfg ):
     class commands( A1FieldCfg.commands ):
         class ranges( A1FieldCfg.commands.ranges ):
-            lin_vel_x = [1.0, 4.0]
+            lin_vel_x = [1.5, 3.0]
             lin_vel_y = [0.0, 0.0]
             ang_vel_yaw = [0., 0.]
@@ -63,11 +63,12 @@ class A1LeapCfg( A1FieldCfg ):
             tracking_ang_vel = 0.05
             world_vel_l2norm = -1.
             legs_energy_substeps = -1e-6
-            alive = 2.
+            alive = 1. # 2.
             penetrate_depth = -4e-3
             penetrate_volume = -4e-3
             exceed_dof_pos_limits = -1e-1
             exceed_torque_limits_i = -2e-1
+            lin_pos_x = 1.
             # track_predict_vel_l2norm = -1.
             soft_dof_pos_limit = 0.9
@@ -82,7 +83,7 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
     class algorithm( A1FieldCfgPPO.algorithm ):
         entropy_coef = 0.0
         clip_min_std = 0.2
-        lin_vel_x = [2.0, 3.0]
+        lin_vel_x = [0.5, 3.0]
         command_scale = 2.0
     class runner( A1FieldCfgPPO.runner ):
@@ -101,7 +102,7 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
         resume = True
         # load_run = "{Your traind walking model directory}"
         # load_run = "May16_18-12-08_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
-        load_run = "High_speed_walk"
+        load_run = "Leap_2m_2500"
         # load_run = "May15_21-34-27_Skillleap_pEnergySubsteps-1e-06_virtual"#"May15_17-07-38_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
         # load_run = "{Your virtually trained leap model directory}"
         max_iterations = 20000

View File

@@ -1037,7 +1037,9 @@ class LeggedRobotField(LeggedRobot):
         world_vel_error = torch.sum(torch.square(self.commands[:, :2] - self.root_states[:, 7:9]), dim= 1)
         return (1 - torch.exp(-world_vel_error/self.cfg.rewards.tracking_sigma)) * engaging_mask # reverse version of tracking reward
 
+    def _reward_lin_pos_x(self):
+        return torch.abs((self.root_states[:, :3] - self.env_origins)[:, 0])
 
     ##### Some helper functions that override parent class attributes #####
     @property
     def all_obs_components(self):
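For context, a minimal sketch of the name-based reward wiring that legged_gym-style environments use, which is why the new `lin_pos_x = 1.` scale in the config above is enough to activate this method. The toy class below is illustrative only, not the repository's implementation (the real code also multiplies each term by the timestep and handles termination rewards, among other details):

import torch

class ToyRewardMixin:
    """Illustrative only: mimics how legged_gym-style envs map cfg scales to _reward_* methods."""
    def __init__(self, reward_scales, num_envs=4, device="cpu"):
        self.device = device
        self.num_envs = num_envs
        # root_states[:, 0:3] = base position; env_origins = per-env spawn points
        self.root_states = torch.zeros(num_envs, 13, device=device)
        self.env_origins = torch.zeros(num_envs, 3, device=device)
        # keep only nonzero scales and resolve each name to a _reward_<name> method
        self.reward_scales = {k: v for k, v in reward_scales.items() if v != 0.}
        self.reward_functions = {k: getattr(self, "_reward_" + k) for k in self.reward_scales}

    def _reward_lin_pos_x(self):
        # same expression as the new method above: forward displacement from the env origin
        return torch.abs((self.root_states[:, :3] - self.env_origins)[:, 0])

    def compute_reward(self):
        rew = torch.zeros(self.num_envs, device=self.device)
        for name, scale in self.reward_scales.items():
            rew += self.reward_functions[name]() * scale
        return rew

env = ToyRewardMixin({"lin_pos_x": 1.})
env.root_states[:, 0] = torch.tensor([0.5, 1.0, 2.0, 3.0])
print(env.compute_reward())  # tensor([0.5000, 1.0000, 2.0000, 3.0000])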

View File

@@ -106,14 +106,14 @@ def play(args):
         # "tilt",
     ]
     env_cfg.terrain.BarrierTrack_kwargs["leap"] = dict(
-        length= (1.3, 1.3),
+        length= (1.5, 1.5),
         depth= (0.4, 0.8),
         height= 0.2,
     )
     if "one_obstacle_per_track" in env_cfg.terrain.BarrierTrack_kwargs.keys():
         env_cfg.terrain.BarrierTrack_kwargs.pop("one_obstacle_per_track")
-    env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 2# 2
+    env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 1# 2
     env_cfg.commands.ranges.lin_vel_x = [3.0, 3.0] # [1.2, 1.2]
     env_cfg.terrain.BarrierTrack_kwargs['track_block_length']= 3.
     if "distill" in args.task:
@@ -239,11 +239,11 @@ def play(args):
             if "obs_slice" in locals().keys():
                 obs_component = obs[:, obs_slice[0]].reshape(-1, *obs_slice[1])
                 print(obs_component[robot_index])
-            vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
+            vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
             velocity = velocity_planner(vel_obs)
-            print(velocity)
-            print(env_cfg.commands.ranges.lin_vel_x)
-            velocity = torch.clip(velocity, env_cfg.commands.ranges.lin_vel_x[0], env_cfg.commands.ranges.lin_vel_x[1])
+            env.commands[..., 0] = velocity.squeeze(-1)
+            obs[..., 9] = velocity.squeeze(-1) * env.obs_scales.lin_vel
+            # velocity = torch.clip(velocity, env_cfg.commands.ranges.lin_vel_x[0], env_cfg.commands.ranges.lin_vel_x[1])
             actions = policy(obs.detach())
             teacher_actions = actions
             obs, critic_obs, rews, dones, infos = env.step(actions.detach(), velocity)
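The play-loop change above drops the debug prints and the clipping, and instead writes the planner output directly into the command buffer and into the observation the policy sees. A small sketch of that injection, under the layout this diff assumes (obs dims 9:12 hold the scaled velocity command, with index 9 being lin_vel_x, matching the obs[..., :9] / obs[..., 12:] slicing); the helper name and the toy buffers below are hypothetical:

import torch

def inject_planned_velocity(obs, commands, velocity, lin_vel_scale):
    """Overwrite the x-velocity command with the planner output, in place.

    Assumes obs[..., 9] is the scaled lin_vel_x command slot, as implied by
    the slicing used to build the planner input in this commit.
    """
    v = velocity.squeeze(-1)          # planner output has shape (num_envs, 1)
    commands[..., 0] = v              # raw command used by the env / rewards
    obs[..., 9] = v * lin_vel_scale   # scaled copy fed to the policy
    return obs, commands

# Toy usage with random tensors standing in for the real env buffers.
num_envs, obs_dim = 4, 48
obs = torch.randn(num_envs, obs_dim)
commands = torch.zeros(num_envs, 4)
velocity = torch.rand(num_envs, 1) * 3.0   # stand-in for velocity_planner(vel_obs)
obs, commands = inject_planned_velocity(obs, commands, velocity, lin_vel_scale=2.0)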

View File

@@ -32,6 +32,7 @@ from collections import defaultdict
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import numpy as np
 import copy
 
 from rsl_rl.modules import ActorCritic
@@ -99,9 +100,11 @@ class PPO:
 
     def test_mode(self):
         self.actor_critic.test()
+        self.velocity_planner.eval()
 
     def train_mode(self):
         self.actor_critic.train()
+        self.velocity_planner.train()
 
     def act(self, obs, critic_obs):
         if self.actor_critic.is_recurrent:
@@ -109,10 +112,10 @@ class PPO:
         # Compute the actions and values
         vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
         velocity = self.velocity_planner(vel_obs)
-        if self.lin_vel_x is not None:
-            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        velocity *= self.command_scale
-        self.transition.actions = self.actor_critic.act(obs, velocity=velocity)[0].detach()
+        # if self.lin_vel_x is not None:
+        #     velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.transition.actions = self.actor_critic.act(obs, velocity=velocity * self.command_scale)[0].detach()
+        critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
         self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
         self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.actor_critic.action_mean.detach()
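`self.velocity_planner` is referenced throughout this hunk but not defined anywhere in the diff. For orientation only, a minimal sketch of such a planner as a small MLP mapping the command-stripped observation (obs with dims 9:12 removed, per the torch.cat above) to one scalar velocity; the layer sizes and the Softplus output are assumptions, not the repository's actual module. The routing below mirrors act(): the actor receives velocity * command_scale, and the critic's command slot (index 9) is overwritten with the same scaled value so both networks see a consistent command:

import torch
import torch.nn as nn

class VelocityPlanner(nn.Module):
    """Illustrative stand-in: maps the observation minus its 3 command dims to one commanded velocity."""
    def __init__(self, num_obs, hidden_dims=(128, 64)):
        super().__init__()
        layers, in_dim = [], num_obs - 3                 # obs[..., :9] + obs[..., 12:]
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ELU()]
            in_dim = h
        layers += [nn.Linear(in_dim, 1), nn.Softplus()]  # keep the planned speed non-negative
        self.net = nn.Sequential(*layers)

    def forward(self, vel_obs):
        return self.net(vel_obs)                         # shape (..., 1)

# Mirrors the routing in act().
num_envs, num_obs, command_scale = 4, 48, 2.0
obs = torch.randn(num_envs, num_obs)
critic_obs = obs.clone()
planner = VelocityPlanner(num_obs)
vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
velocity = planner(vel_obs)
critic_obs[..., 9] = velocity.squeeze(-1) * command_scale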
@@ -148,7 +151,7 @@ class PPO:
         generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
         for minibatch in generator:
-            losses, _, stats = self.compute_losses(minibatch)
+            losses, _, stats = self.compute_losses(minibatch, current_learning_iteration=current_learning_iteration)
 
             loss = 0.
             for k, v in losses.items():
@@ -177,15 +180,15 @@ class PPO:
         return mean_losses, average_stats
 
-    def compute_losses(self, minibatch):
+    def compute_losses(self, minibatch, current_learning_iteration=None):
         obs = copy.deepcopy(minibatch.obs)
+        # print(obs.shape)
         vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
+        # print(vel_obs.shape)
         velocity = self.velocity_planner(vel_obs)
-        if self.lin_vel_x is not None:
-            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity)
+        # if self.lin_vel_x is not None:
+        #     velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
         actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
         value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
         mu_batch = self.actor_critic.action_mean
@@ -228,9 +231,17 @@ class PPO:
         else:
             value_loss = (minibatch.returns - value_batch).pow(2).mean()
 
+        # Velocity loss
+        if current_learning_iteration is None:
+            vel_loss = 0
+        else:
+            vel_loss = torch.square(velocity-2).mean() * np.exp(-0.01 * current_learning_iteration + 125)
+            vel_loss += torch.square(torch.clamp_max(velocity, 1.) - 1).mean()
+
         return_ = dict(
             surrogate_loss= surrogate_loss,
             value_loss= value_loss,
+            vel_loss= vel_loss,
         )
         if entropy_batch is not None:
             return_["entropy"] = - entropy_batch.mean()

View File

@@ -221,6 +221,7 @@ class OnPolicyRunner:
                           'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
                       f"""{'Value function loss:':>{pad}} {locs["losses"]['value_loss']:.4f}\n"""
                       f"""{'Surrogate loss:':>{pad}} {locs["losses"]['surrogate_loss']:.4f}\n"""
+                      f"""{'Velocity loss:':>{pad}} {locs["losses"]['vel_loss']:.4f}\n"""
                       f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
                       # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
                       # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n"""
@@ -238,6 +239,7 @@ class OnPolicyRunner:
     def save(self, path, infos=None):
         run_state_dict = {
             'model_state_dict': self.alg.actor_critic.state_dict(),
+            'velocity_planner_state_dict': self.alg.velocity_planner.state_dict(),
             'optimizer_state_dict': self.alg.optimizer.state_dict(),
             'velocity_optimizer_state_dict': self.alg.velocity_optimizer.state_dict(),
             'iter': self.current_learning_iteration,
@@ -250,6 +252,8 @@ class OnPolicyRunner:
     def load(self, path, load_optimizer=True):
         loaded_dict = torch.load(path)
         self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
+        if 'velocity_planner_state_dict' in loaded_dict:
+            self.alg.velocity_planner.load_state_dict(loaded_dict['velocity_planner_state_dict'])
         if load_optimizer and "optimizer_state_dict" in loaded_dict:
             self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
         if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict: