velocity in ppo
Commit edde9b8f8a (parent 8c183fd9ab)
@@ -30,7 +30,6 @@ class A1LeapCfg( A1FieldCfg ):
             virtual_terrain= True, # Change this to False for real terrain
             no_perlin_threshold= 0.06,
             n_obstacles_curriculum = True,
-            n_obstacles_per_track=2,
         ))

         TerrainPerlin_kwargs = merge_dict(A1FieldCfg.terrain.TerrainPerlin_kwargs, dict(
@@ -82,6 +81,9 @@ class A1LeapCfg( A1FieldCfg ):


 class A1LeapCfgPPO( A1FieldCfgPPO ):
+    class policy(A1FieldCfgPPO.policy):
+        num_critic_obs = 78
+
     class algorithm( A1FieldCfgPPO.algorithm ):
         entropy_coef = 0.0
         clip_min_std = 0.2
@@ -104,7 +106,8 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
         resume = True
         # load_run = "{Your trained walking model directory}"
         # load_run = "May16_18-12-08_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
-        load_run = "Leap_loss_success"
+        # load_run = "Leap_loss_success"
+        load_run = 'High_speed_walk'
         # load_run = "May15_21-34-27_Skillleap_pEnergySubsteps-1e-06_virtual"#"May15_17-07-38_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
         # load_run = "{Your virtually trained leap model directory}"
         max_iterations = 20000
@@ -115,14 +115,14 @@ class PPO:
         # if self.lin_vel_x is not None:
         # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
         self.transition.actions = self.actor_critic.act(obs, velocity=velocity * self.command_scale)[0].detach()
-        critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
-        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
+        # critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
+        self.transition.values = self.actor_critic.evaluate(vel_obs).detach()
         self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.actor_critic.action_mean.detach()
         self.transition.action_sigma = self.actor_critic.action_std.detach()
         # need to record obs and critic_obs before env.step()
-        self.transition.observations = obs
-        self.transition.critic_observations = critic_obs
+        self.transition.observations = vel_obs
+        self.transition.critic_observations = vel_obs
         return self.transition.actions, velocity.squeeze().detach()

     def process_env_step(self, rewards, dones, infos):
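With this change the rollout stores command-free observations: the velocity planner's output is passed to the actor separately, and both the stored observations and the critic input are vel_obs, which, judging by the matching slices in compute_returns and compute_losses below, is the observation with the 3 commanded-velocity dimensions removed. A minimal sketch of that slicing, assuming the command occupies indices 9:12; the 81-dim size is illustrative, not taken from the repo:

    import torch

    def strip_command(obs: torch.Tensor) -> torch.Tensor:
        # Drop the 3 command dims assumed to sit at indices 9:12, mirroring
        # torch.cat([obs[..., :9], obs[..., 12:]], dim=-1) used in this commit.
        return torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)

    obs = torch.randn(4096, 81)      # full per-env observation (size illustrative)
    vel_obs = strip_command(obs)     # -> torch.Size([4096, 78])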
@@ -138,6 +138,7 @@ class PPO:
         self.actor_critic.reset(dones)

     def compute_returns(self, last_critic_obs):
+        last_critic_obs = torch.cat([last_critic_obs[..., :9], last_critic_obs[..., 12:]], dim=-1)
         last_values= self.actor_critic.evaluate(last_critic_obs).detach()
         self.storage.compute_returns(last_values, self.gamma, self.lam)

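compute_returns now applies the same 9:12 slice to last_critic_obs, so the bootstrap value is computed on the same command-free input width the critic saw during the rollout. For context, that bootstrap value feeds the standard GAE recursion that storage.compute_returns in rsl_rl-style code typically implements; the sketch below is a generic reference implementation, not this repo's exact code:

    import torch

    def gae_returns(rewards, values, dones, last_values, gamma=0.99, lam=0.95):
        # rewards, values, dones: (T, N) tensors; last_values: (N,) = V(s_T),
        # here computed from the stripped last_critic_obs.
        T = rewards.shape[0]
        returns = torch.zeros_like(rewards)
        advantage = torch.zeros_like(last_values)
        for t in reversed(range(T)):
            next_value = last_values if t == T - 1 else values[t + 1]
            not_done = 1.0 - dones[t].float()
            delta = rewards[t] + gamma * next_value * not_done - values[t]
            advantage = delta + gamma * lam * not_done * advantage
            returns[t] = advantage + values[t]
        return returns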
@@ -183,12 +184,14 @@ class PPO:
     def compute_losses(self, minibatch, current_learning_iteration=None):
         obs = copy.deepcopy(minibatch.obs)

-        vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
+        # vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)

-        velocity = self.velocity_planner(vel_obs)
+        velocity = self.velocity_planner(obs)
+        zeros = torch.zeros_like(velocity, requires_grad=False)
+        obs_with_cmd = torch.cat([obs[..., :9], velocity, zeros, zeros, obs[..., 9:]], dim=-1)
         # if self.lin_vel_x is not None:
         # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
+        self.actor_critic.act(obs_with_cmd, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
         actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
         value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
         mu_batch = self.actor_critic.action_mean
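In the update step the planner is run on the stored, command-free observations and its output is spliced back in as the forward-velocity command, with the lateral and yaw commands zeroed, before re-evaluating the actor. A sketch of that splice, under the same assumption that the command slot starts at index 9:

    import torch

    def insert_command(obs_no_cmd: torch.Tensor, vx: torch.Tensor) -> torch.Tensor:
        # Rebuild a full observation by inserting [vx, 0, 0] at index 9,
        # mirroring the obs_with_cmd construction above.
        zeros = torch.zeros_like(vx)
        return torch.cat([obs_no_cmd[..., :9], vx, zeros, zeros, obs_no_cmd[..., 9:]], dim=-1)

    vel_obs = torch.randn(4096, 78)             # stored observations without the command
    vx = torch.randn(4096, 1)                   # hypothetical velocity_planner output
    obs_with_cmd = insert_command(vel_obs, vx)  # -> torch.Size([4096, 81])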
@@ -78,7 +78,7 @@ class OnPolicyRunner:
         self.save_interval = self.cfg["save_interval"]

         # init storage and model
-        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])
+        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs - 3], [self.env.num_privileged_obs], [self.env.num_actions])

         # Log
         self.log_dir = log_dir
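The rollout storage is now allocated for observations that are 3 dims narrower than env.num_obs, matching the stripped command, while the privileged/critic observation size is left unchanged. A rough illustration of the resulting buffer shape (sizes are assumptions for illustration only):

    import torch

    num_envs, num_steps, num_obs = 4096, 24, 81   # illustrative values
    # Per-step observation buffer sized like the init_storage call above.
    obs_buffer = torch.zeros(num_steps, num_envs, num_obs - 3)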
@@ -251,11 +251,22 @@ class OnPolicyRunner:

     def load(self, path, load_optimizer=True):
         loaded_dict = torch.load(path)
-        self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
+
+        try:
+            self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'], strict=False)
+            if load_optimizer and "optimizer_state_dict" in loaded_dict:
+                self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
+        except:
+            from collections import OrderedDict
+            new_state = OrderedDict()
+            for k, v in loaded_dict['model_state_dict'].items():
+                if "memory_c" in k or "critic" in k:
+                    continue
+                new_state[k] = v
+            self.alg.actor_critic.load_state_dict(new_state, strict=False)
+
         if 'velocity_planner_state_dict' in loaded_dict:
             self.alg.velocity_planner.load_state_dict(loaded_dict['velocity_planner_state_dict'])
-        if load_optimizer and "optimizer_state_dict" in loaded_dict:
-            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
         if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict:
             self.alg.velocity_optimizer.load_state_dict(loaded_dict['velocity_optimizer_state_dict'], )
         if "lr_scheduler_state_dict" in loaded_dict:
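The loader now first attempts a tolerant load with strict=False (restoring the optimizer when present) and, if that raises, falls back to keeping only the parameters that are not part of the critic or its recurrent memory, presumably the parameters whose shapes change once the critic input is resized. The same filtering pattern in isolation, as a generic sketch using the 'memory_c'/'critic' key prefixes from this diff:

    from collections import OrderedDict
    import torch

    def load_actor_weights_only(model: torch.nn.Module, path: str) -> None:
        loaded = torch.load(path, map_location="cpu")
        state = loaded["model_state_dict"]
        # Skip critic-side parameters whose shapes may no longer match.
        filtered = OrderedDict(
            (k, v) for k, v in state.items()
            if "memory_c" not in k and "critic" not in k
        )
        model.load_state_dict(filtered, strict=False)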