velocity in ppo

This commit is contained in:
Jerry Xu 2024-06-02 16:05:57 -04:00
parent 8c183fd9ab
commit edde9b8f8a
3 changed files with 30 additions and 13 deletions

View File

@ -30,7 +30,6 @@ class A1LeapCfg( A1FieldCfg ):
virtual_terrain= True, # Change this to False for real terrain
no_perlin_threshold= 0.06,
n_obstacles_curriculum = True,
n_obstacles_per_track=2,
))
TerrainPerlin_kwargs = merge_dict(A1FieldCfg.terrain.TerrainPerlin_kwargs, dict(
@ -82,6 +81,9 @@ class A1LeapCfg( A1FieldCfg ):
class A1LeapCfgPPO( A1FieldCfgPPO ):
class policy(A1FieldCfgPPO.policy):
num_critic_obs = 78
class algorithm( A1FieldCfgPPO.algorithm ):
entropy_coef = 0.0
clip_min_std = 0.2
@ -104,7 +106,8 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
resume = True
# load_run = "{Your traind walking model directory}"
# load_run = "May16_18-12-08_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
load_run = "Leap_loss_success"
# load_run = "Leap_loss_success"
load_run = 'High_speed_walk'
# load_run = "May15_21-34-27_Skillleap_pEnergySubsteps-1e-06_virtual"#"May15_17-07-38_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
# load_run = "{Your virtually trained leap model directory}"
max_iterations = 20000

View File

@ -115,14 +115,14 @@ class PPO:
# if self.lin_vel_x is not None:
# velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
self.transition.actions = self.actor_critic.act(obs, velocity=velocity * self.command_scale)[0].detach()
critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
# critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
self.transition.values = self.actor_critic.evaluate(vel_obs).detach()
self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
self.transition.action_mean = self.actor_critic.action_mean.detach()
self.transition.action_sigma = self.actor_critic.action_std.detach()
# need to record obs and critic_obs before env.step()
self.transition.observations = obs
self.transition.critic_observations = critic_obs
self.transition.observations = vel_obs
self.transition.critic_observations = vel_obs
return self.transition.actions, velocity.squeeze().detach()
def process_env_step(self, rewards, dones, infos):
@ -138,6 +138,7 @@ class PPO:
self.actor_critic.reset(dones)
def compute_returns(self, last_critic_obs):
last_critic_obs = torch.cat([last_critic_obs[..., :9], last_critic_obs[..., 12:]], dim=-1)
last_values= self.actor_critic.evaluate(last_critic_obs).detach()
self.storage.compute_returns(last_values, self.gamma, self.lam)
@ -183,12 +184,14 @@ class PPO:
def compute_losses(self, minibatch, current_learning_iteration=None):
obs = copy.deepcopy(minibatch.obs)
vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
# vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
velocity = self.velocity_planner(vel_obs)
velocity = self.velocity_planner(obs)
zeros = torch.zeros_like(velocity, requires_grad=False)
obs_with_cmd = torch.cat([obs[..., :9], velocity, zeros, zeros, obs[..., 9:]], dim=-1)
# if self.lin_vel_x is not None:
# velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
self.actor_critic.act(obs_with_cmd, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
mu_batch = self.actor_critic.action_mean

View File

@ -78,7 +78,7 @@ class OnPolicyRunner:
self.save_interval = self.cfg["save_interval"]
# init storage and model
self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])
self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs - 3], [self.env.num_privileged_obs], [self.env.num_actions])
# Log
self.log_dir = log_dir
@ -251,11 +251,22 @@ class OnPolicyRunner:
def load(self, path, load_optimizer=True):
loaded_dict = torch.load(path)
self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
try:
self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'], strict=False)
if load_optimizer and "optimizer_state_dict" in loaded_dict:
self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
except:
from collections import OrderedDict
new_state = OrderedDict()
for k, v in loaded_dict['model_state_dict'].items():
if "memory_c" in k or "critic" in k:
continue
new_state[k] = v
self.alg.actor_critic.load_state_dict(new_state, strict=False)
if 'velocity_planner_state_dict' in loaded_dict:
self.alg.velocity_planner.load_state_dict(loaded_dict['velocity_planner_state_dict'])
if load_optimizer and "optimizer_state_dict" in loaded_dict:
self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict:
self.alg.velocity_optimizer.load_state_dict(loaded_dict['velocity_optimizer_state_dict'], )
if "lr_scheduler_state_dict" in loaded_dict: