errors

parent 30c69b3dab
commit da35cb1cb6
@@ -10,7 +10,8 @@ class A1LeapCfg( A1FieldCfg ):
    # delay_action_obs = True
    # latency_range = [0.04-0.0025, 0.04+0.0075]
    #### uncomment the above to train non-virtual terrain

    class env(A1FieldCfg.env):
        num_envs = 4

    class terrain( A1FieldCfg.terrain ):
        max_init_terrain_level = 2
        border_size = 5
@@ -68,6 +69,7 @@ class A1LeapCfg( A1FieldCfg ):
            exceed_dof_pos_limits = -1e-1
            exceed_torque_limits_i = -2e-1
            # track_predict_vel_l2norm = -1.
        soft_dof_pos_limit = 0.9

    class curriculum( A1FieldCfg.curriculum ):
        penetrate_volume_threshold_harder = 9000
@@ -80,6 +82,8 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
    class algorithm( A1FieldCfgPPO.algorithm ):
        entropy_coef = 0.0
        clip_min_std = 0.2
        lin_vel_x = [2.0, 3.0]
        command_scale = 2.0

    class runner( A1FieldCfgPPO.runner ):
        policy_class_name = "ActorCriticRecurrent"
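For reference, putting the two new algorithm-config values together with the PPO change further down: the planner output is clipped to lin_vel_x and then multiplied by command_scale before being written into the observation, so with these settings the effective forward-velocity command lands in [4.0, 6.0]. A tiny check of that arithmetic (the config values are read off this diff, the example tensor is made up):

import torch

lin_vel_x = [2.0, 3.0]       # from A1LeapCfgPPO.algorithm above
command_scale = 2.0

raw = torch.tensor([-1.0, 2.5, 10.0])               # example planner outputs
cmd = torch.clip(raw, lin_vel_x[0], lin_vel_x[1]) * command_scale
print(cmd)                                           # tensor([4., 5., 6.])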
@@ -77,6 +77,7 @@ class PPO:
        self.velocity_planner = velocity_planner
        self.velocity_optimizer = getattr(optim, optimizer_class_name)(self.velocity_planner.parameters(), lr=learning_rate)
        self.lin_vel_x = kwargs.get('lin_vel_x', None)
        self.command_scale = kwargs.get('command_scale', 2.0)

        # PPO parameters
        self.clip_param = clip_param
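The diff only passes velocity_planner into PPO and builds an optimizer for it; the module itself is not defined in this commit. As a rough, hypothetical sketch (not taken from the repository), one plausible shape is a small MLP that maps the command-stripped observation to a single forward-velocity value per environment:

import torch
import torch.nn as nn

class VelocityPlanner(nn.Module):
    """Hypothetical sketch: observation without the 3 command entries -> 1 velocity."""
    def __init__(self, num_vel_obs, hidden_dims=(128, 64)):
        super().__init__()
        layers, in_dim = [], num_vel_obs
        for h in hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ELU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, 1))   # one scalar command per env
        self.net = nn.Sequential(*layers)

    def forward(self, vel_obs):
        return self.net(vel_obs)

# velocity_planner = VelocityPlanner(num_vel_obs=45)   # e.g. a 48-dim obs minus the 3 command entries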
@@ -106,11 +107,11 @@ class PPO:
        if self.actor_critic.is_recurrent:
            self.transition.hidden_states = self.actor_critic.get_hidden_states()
        # Compute the actions and values
-       vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
+       vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
        velocity = self.velocity_planner(vel_obs)
        if self.lin_vel_x is not None:
            velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])

        velocity *= self.command_scale
        self.transition.actions = self.actor_critic.act(obs, velocity=velocity)[0].detach()
        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
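The switch from obs[:, :9] / dim=1 to obs[..., :9] / dim=-1 makes the command-stripping slice always act on the feature dimension, no matter how many leading dimensions the batch has; with the old [:, ...] slice, a (time, num_envs, obs_dim) rollout from the recurrent policy would have been cut along the environment axis instead. A small shape check (the 48-dim observation is an assumption for illustration; the dropped 9:12 slice is taken from this diff):

import torch

obs_dim = 48                                   # assumed size, illustration only
flat = torch.randn(4, obs_dim)                 # (num_envs, obs_dim)
rollout = torch.randn(24, 4, obs_dim)          # (time, num_envs, obs_dim)

for obs in (flat, rollout):
    vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
    assert vel_obs.shape[:-1] == obs.shape[:-1]   # leading dims preserved
    assert vel_obs.shape[-1] == obs_dim - 3       # entries 9:12 dropped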
@@ -75,7 +75,8 @@ class ActorCriticRecurrent(ActorCritic):

    def act(self, observations, masks=None, hidden_states=None, velocity=None):
        if velocity is not None:
-           observations[..., 9] = velocity.squeeze()
+           # print(velocity.squeeze())
+           observations[..., 9] = velocity.squeeze(-1)
        # vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
        # velocity = self.velocity_planner(vel_obs)
        # velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
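On the squeeze() to squeeze(-1) change: squeeze() drops every size-1 dimension, so with a single environment (or a length-1 rollout) it also removes the batch dimension and the write into observations[..., 9] no longer broadcasts; squeeze(-1) removes only the trailing singleton produced by the planner's 1-unit output layer. A minimal illustration, with shapes chosen only as an example:

import torch

obs = torch.zeros(16, 1, 48)          # (time, num_envs=1, obs_dim)
velocity = torch.ones(16, 1, 1)       # planner output: one scalar per step and env

obs[..., 9] = velocity.squeeze(-1)    # (16, 1) matches obs[..., 9]  -> ok
try:
    obs[..., 9] = velocity.squeeze()  # (16,) cannot broadcast to (16, 1)
except RuntimeError as err:
    print("squeeze() fails:", err)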