velocity in ppo
commit edde9b8f8a
parent 8c183fd9ab
@@ -30,7 +30,6 @@ class A1LeapCfg( A1FieldCfg ):
             virtual_terrain= True, # Change this to False for real terrain
             no_perlin_threshold= 0.06,
             n_obstacles_curriculum = True,
-            n_obstacles_per_track=2,
         ))

         TerrainPerlin_kwargs = merge_dict(A1FieldCfg.terrain.TerrainPerlin_kwargs, dict(
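The two merge_dict(...) calls above override selected keys of the base terrain kwargs while keeping the rest. A minimal illustration of that override pattern using a plain dict update (the base values below are made up for the example and merge_dict itself is the repository's helper, not shown here):

# Illustrative only: emulate the merge_dict-style override with a plain dict update.
base_kwargs = {"n_obstacles_curriculum": False, "no_perlin_threshold": 0.02}   # made-up base values
overrides = {"n_obstacles_curriculum": True, "no_perlin_threshold": 0.06}      # values from the config above
merged = {**base_kwargs, **overrides}   # later keys win, so the overrides replace the base entries
print(merged)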
@@ -82,6 +81,9 @@ class A1LeapCfg( A1FieldCfg ):


 class A1LeapCfgPPO( A1FieldCfgPPO ):
+    class policy(A1FieldCfgPPO.policy):
+        num_critic_obs = 78
+
     class algorithm( A1FieldCfgPPO.algorithm ):
         entropy_coef = 0.0
         clip_min_std = 0.2
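The new policy override sets num_critic_obs = 78, presumably the observation size with the 3 commanded-velocity dims removed (see the PPO hunks below). These configs rely on plain class inheritance: overriding one attribute in a nested subclass leaves everything else from the base config intact. A toy sketch of that pattern (BasePolicyCfg and its values are illustrative stand-ins, not the actual A1FieldCfgPPO.policy):

class BasePolicyCfg:                 # stand-in for A1FieldCfgPPO.policy
    num_critic_obs = 81              # illustrative base value only
    activation = "elu"

class LeapPolicyCfg(BasePolicyCfg):  # stand-in for A1LeapCfgPPO.policy
    num_critic_obs = 78              # only this field is overridden; activation is inherited

print(LeapPolicyCfg.num_critic_obs, LeapPolicyCfg.activation)  # -> 78 elu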
@@ -104,7 +106,8 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
         resume = True
         # load_run = "{Your traind walking model directory}"
         # load_run = "May16_18-12-08_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
-        load_run = "Leap_loss_success"
+        # load_run = "Leap_loss_success"
+        load_run = 'High_speed_walk'
         # load_run = "May15_21-34-27_Skillleap_pEnergySubsteps-1e-06_virtual"#"May15_17-07-38_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
         # load_run = "{Your virtually trained leap model directory}"
         max_iterations = 20000
@@ -115,14 +115,14 @@ class PPO:
         # if self.lin_vel_x is not None:
         # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
         self.transition.actions = self.actor_critic.act(obs, velocity=velocity * self.command_scale)[0].detach()
-        critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
-        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
+        # critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
+        self.transition.values = self.actor_critic.evaluate(vel_obs).detach()
         self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.actor_critic.action_mean.detach()
         self.transition.action_sigma = self.actor_critic.action_std.detach()
         # need to record obs and critic_obs before env.step()
-        self.transition.observations = obs
-        self.transition.critic_observations = critic_obs
+        self.transition.observations = vel_obs
+        self.transition.critic_observations = vel_obs
         return self.transition.actions, velocity.squeeze().detach()

     def process_env_step(self, rewards, dones, infos):
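In act(), the stored observations and the critic input switch from obs/critic_obs to vel_obs. vel_obs is not constructed in this hunk, but the slicing added in compute_returns and compute_losses below suggests it drops the 3-dim commanded velocity at indices 9:12. A minimal sketch of that slicing under that assumed layout (batch and observation sizes are illustrative):

import torch

def strip_velocity_command(obs: torch.Tensor) -> torch.Tensor:
    # Assumed layout: the 3-dim velocity command occupies obs[..., 9:12].
    return torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)

obs = torch.randn(8, 81)                       # 8 envs, 81-dim obs (illustrative numbers)
vel_obs = strip_velocity_command(obs)
assert vel_obs.shape[-1] == obs.shape[-1] - 3  # matches the num_obs - 3 storage size below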
@@ -138,6 +138,7 @@ class PPO:
         self.actor_critic.reset(dones)

     def compute_returns(self, last_critic_obs):
+        last_critic_obs = torch.cat([last_critic_obs[..., :9], last_critic_obs[..., 12:]], dim=-1)
         last_values= self.actor_critic.evaluate(last_critic_obs).detach()
         self.storage.compute_returns(last_values, self.gamma, self.lam)

@@ -183,12 +184,14 @@ class PPO:
     def compute_losses(self, minibatch, current_learning_iteration=None):
         obs = copy.deepcopy(minibatch.obs)

-        vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
+        # vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)

-        velocity = self.velocity_planner(vel_obs)
+        velocity = self.velocity_planner(obs)
+        zeros = torch.zeros_like(velocity, requires_grad=False)
+        obs_with_cmd = torch.cat([obs[..., :9], velocity, zeros, zeros, obs[..., 9:]], dim=-1)
         # if self.lin_vel_x is not None:
         # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
+        self.actor_critic.act(obs_with_cmd, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
         actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
         value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
         mu_batch = self.actor_critic.action_mean
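Here the planner output (one scalar per env) is spliced back into the command slots, with two zero fillers presumably standing in for the lateral and yaw commands, before the actor is queried on obs_with_cmd. A short sketch of that splice under the same assumed layout (shapes are illustrative):

import torch

obs = torch.randn(8, 78)                         # stored observations without the command dims (illustrative)
velocity = torch.rand(8, 1)                      # velocity_planner output: forward velocity per env (assumed shape)
zeros = torch.zeros_like(velocity)               # fillers for the other two command slots
obs_with_cmd = torch.cat([obs[..., :9], velocity, zeros, zeros, obs[..., 9:]], dim=-1)
assert obs_with_cmd.shape[-1] == obs.shape[-1] + 3   # the 3 command dims are restored at indices 9:12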
@@ -78,7 +78,7 @@ class OnPolicyRunner:
         self.save_interval = self.cfg["save_interval"]

         # init storage and model
-        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])
+        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs - 3], [self.env.num_privileged_obs], [self.env.num_actions])

         # Log
         self.log_dir = log_dir
@@ -251,11 +251,22 @@ class OnPolicyRunner:

     def load(self, path, load_optimizer=True):
         loaded_dict = torch.load(path)
-        self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
+        try:
+            self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'], strict=False)
+            if load_optimizer and "optimizer_state_dict" in loaded_dict:
+                self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
+        except:
+            from collections import OrderedDict
+            new_state = OrderedDict()
+            for k, v in loaded_dict['model_state_dict'].items():
+                if "memory_c" in k or "critic" in k:
+                    continue
+                new_state[k] = v
+            self.alg.actor_critic.load_state_dict(new_state, strict=False)
+
         if 'velocity_planner_state_dict' in loaded_dict:
             self.alg.velocity_planner.load_state_dict(loaded_dict['velocity_planner_state_dict'])
-        if load_optimizer and "optimizer_state_dict" in loaded_dict:
-            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
         if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict:
             self.alg.velocity_optimizer.load_state_dict(loaded_dict['velocity_optimizer_state_dict'], )
         if "lr_scheduler_state_dict" in loaded_dict:
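The reworked load() first attempts a non-strict load (so checkpoints whose critic or critic-memory layers no longer match still load partially) and, if that fails, falls back to copying only the keys that do not contain "memory_c" or "critic". A standalone sketch of the same pattern with a toy module in place of the actual actor-critic (names and sizes are illustrative):

import torch.nn as nn
from collections import OrderedDict

model = nn.ModuleDict({"actor": nn.Linear(8, 4), "critic": nn.Linear(8, 1)})  # toy stand-in
loaded_dict = {"model_state_dict": model.state_dict()}                        # pretend checkpoint

try:
    model.load_state_dict(loaded_dict["model_state_dict"], strict=False)
except (RuntimeError, KeyError):
    # Fallback: keep only actor-side weights, dropping critic / critic-memory keys.
    new_state = OrderedDict(
        (k, v) for k, v in loaded_dict["model_state_dict"].items()
        if "memory_c" not in k and "critic" not in k
    )
    model.load_state_dict(new_state, strict=False)

Note that the bare except: in the diff also swallows unrelated errors; narrowing it as in the sketch is a safer variant of the same idea.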