This commit is contained in:
Jerry Xu 2024-05-24 15:58:48 -04:00
parent 30c69b3dab
commit da35cb1cb6
3 changed files with 10 additions and 4 deletions

View File

@ -10,7 +10,8 @@ class A1LeapCfg( A1FieldCfg ):
# delay_action_obs = True
# latency_range = [0.04-0.0025, 0.04+0.0075]
#### uncomment the above to train non-virtual terrain
class env(A1FieldCfg.env):
num_envs = 4
class terrain( A1FieldCfg.terrain ):
max_init_terrain_level = 2
border_size = 5
@ -68,6 +69,7 @@ class A1LeapCfg( A1FieldCfg ):
exceed_dof_pos_limits = -1e-1
exceed_torque_limits_i = -2e-1
# track_predict_vel_l2norm = -1.
soft_dof_pos_limit = 0.9
class curriculum( A1FieldCfg.curriculum ):
penetrate_volume_threshold_harder = 9000
@ -80,6 +82,8 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
class algorithm( A1FieldCfgPPO.algorithm ):
entropy_coef = 0.0
clip_min_std = 0.2
lin_vel_x = [2.0, 3.0]
command_scale = 2.0
class runner( A1FieldCfgPPO.runner ):
policy_class_name = "ActorCriticRecurrent"

View File

@ -77,6 +77,7 @@ class PPO:
self.velocity_planner = velocity_planner
self.velocity_optimizer = getattr(optim, optimizer_class_name)(self.velocity_planner.parameters(), lr=learning_rate)
self.lin_vel_x = kwargs.get('lin_vel_x', None)
self.command_scale = kwargs.get('command_scale', 2.0)
# PPO parameters
self.clip_param = clip_param
@ -106,11 +107,11 @@ class PPO:
if self.actor_critic.is_recurrent:
self.transition.hidden_states = self.actor_critic.get_hidden_states()
# Compute the actions and values
vel_obs = torch.cat([obs[:, :9], obs[:, 12:]], dim=1)
vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
velocity = self.velocity_planner(vel_obs)
if self.lin_vel_x is not None:
velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
velocity *= self.command_scale
self.transition.actions = self.actor_critic.act(obs, velocity=velocity)[0].detach()
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()

View File

@ -75,7 +75,8 @@ class ActorCriticRecurrent(ActorCritic):
def act(self, observations, masks=None, hidden_states=None, velocity=None):
if velocity is not None:
observations[..., 9] = velocity.squeeze()
# print(velocity.squeeze())
observations[..., 9] = velocity.squeeze(-1)
# vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
# velocity = self.velocity_planner(vel_obs)
# velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])