add h1
This commit is contained in:
parent
4a35df7152
commit
d5ce784850
|
@ -1,5 +1,5 @@
|
||||||
from legged_gym import LEGGED_GYM_ROOT_DIR, envs
|
from legged_gym import LEGGED_GYM_ROOT_DIR, envs
|
||||||
from time import time
|
import time
|
||||||
from warnings import WarningMessage
|
from warnings import WarningMessage
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
@ -14,6 +14,7 @@ from typing import Tuple, Dict
|
||||||
from legged_gym import LEGGED_GYM_ROOT_DIR
|
from legged_gym import LEGGED_GYM_ROOT_DIR
|
||||||
from legged_gym.envs.base.base_task import BaseTask
|
from legged_gym.envs.base.base_task import BaseTask
|
||||||
from legged_gym.utils.math import quat_apply_yaw, wrap_to_pi, torch_rand_sqrt_float
|
from legged_gym.utils.math import quat_apply_yaw, wrap_to_pi, torch_rand_sqrt_float
|
||||||
|
from legged_gym.utils.isaacgym_utils import get_euler_xyz
|
||||||
from legged_gym.utils.helpers import class_to_dict
|
from legged_gym.utils.helpers import class_to_dict
|
||||||
from .legged_robot_config import LeggedRobotCfg
|
from .legged_robot_config import LeggedRobotCfg
|
||||||
|
|
||||||
|
@ -60,6 +61,12 @@ class LeggedRobot(BaseTask):
|
||||||
self.torques = self._compute_torques(self.actions).view(self.torques.shape)
|
self.torques = self._compute_torques(self.actions).view(self.torques.shape)
|
||||||
self.gym.set_dof_actuation_force_tensor(self.sim, gymtorch.unwrap_tensor(self.torques))
|
self.gym.set_dof_actuation_force_tensor(self.sim, gymtorch.unwrap_tensor(self.torques))
|
||||||
self.gym.simulate(self.sim)
|
self.gym.simulate(self.sim)
|
||||||
|
if self.cfg.env.test:
|
||||||
|
elapsed_time = self.gym.get_elapsed_time(self.sim)
|
||||||
|
sim_time = self.gym.get_sim_time(self.sim)
|
||||||
|
if sim_time-elapsed_time>0:
|
||||||
|
time.sleep(sim_time-elapsed_time)
|
||||||
|
|
||||||
if self.device == 'cpu':
|
if self.device == 'cpu':
|
||||||
self.gym.fetch_results(self.sim, True)
|
self.gym.fetch_results(self.sim, True)
|
||||||
self.gym.refresh_dof_state_tensor(self.sim)
|
self.gym.refresh_dof_state_tensor(self.sim)
|
||||||
|
@ -86,6 +93,7 @@ class LeggedRobot(BaseTask):
|
||||||
# prepare quantities
|
# prepare quantities
|
||||||
self.base_pos[:] = self.root_states[:, 0:3]
|
self.base_pos[:] = self.root_states[:, 0:3]
|
||||||
self.base_quat[:] = self.root_states[:, 3:7]
|
self.base_quat[:] = self.root_states[:, 3:7]
|
||||||
|
self.rpy[:] = get_euler_xyz(self.base_quat)
|
||||||
self.base_lin_vel[:] = quat_rotate_inverse(self.base_quat, self.root_states[:, 7:10])
|
self.base_lin_vel[:] = quat_rotate_inverse(self.base_quat, self.root_states[:, 7:10])
|
||||||
self.base_ang_vel[:] = quat_rotate_inverse(self.base_quat, self.root_states[:, 10:13])
|
self.base_ang_vel[:] = quat_rotate_inverse(self.base_quat, self.root_states[:, 10:13])
|
||||||
self.projected_gravity[:] = quat_rotate_inverse(self.base_quat, self.gravity_vec)
|
self.projected_gravity[:] = quat_rotate_inverse(self.base_quat, self.gravity_vec)
|
||||||
|
@ -107,6 +115,7 @@ class LeggedRobot(BaseTask):
|
||||||
""" Check if environments need to be reset
|
""" Check if environments need to be reset
|
||||||
"""
|
"""
|
||||||
self.reset_buf = torch.any(torch.norm(self.contact_forces[:, self.termination_contact_indices, :], dim=-1) > 1., dim=1)
|
self.reset_buf = torch.any(torch.norm(self.contact_forces[:, self.termination_contact_indices, :], dim=-1) > 1., dim=1)
|
||||||
|
self.reset_buf |= torch.logical_or(torch.abs(self.rpy[:,1])>1.0, torch.abs(self.rpy[:,0])>0.8)
|
||||||
self.time_out_buf = self.episode_length_buf > self.max_episode_length # no terminal reward for time-outs
|
self.time_out_buf = self.episode_length_buf > self.max_episode_length # no terminal reward for time-outs
|
||||||
self.reset_buf |= self.time_out_buf
|
self.reset_buf |= self.time_out_buf
|
||||||
|
|
||||||
|
@ -394,9 +403,9 @@ class LeggedRobot(BaseTask):
|
||||||
noise_vec[3:6] = noise_scales.ang_vel * noise_level * self.obs_scales.ang_vel
|
noise_vec[3:6] = noise_scales.ang_vel * noise_level * self.obs_scales.ang_vel
|
||||||
noise_vec[6:9] = noise_scales.gravity * noise_level
|
noise_vec[6:9] = noise_scales.gravity * noise_level
|
||||||
noise_vec[9:12] = 0. # commands
|
noise_vec[9:12] = 0. # commands
|
||||||
noise_vec[12:24] = noise_scales.dof_pos * noise_level * self.obs_scales.dof_pos
|
noise_vec[12:12+self.num_actions] = noise_scales.dof_pos * noise_level * self.obs_scales.dof_pos
|
||||||
noise_vec[24:36] = noise_scales.dof_vel * noise_level * self.obs_scales.dof_vel
|
noise_vec[12+self.num_actions:12+2*self.num_actions] = noise_scales.dof_vel * noise_level * self.obs_scales.dof_vel
|
||||||
noise_vec[36:48] = 0. # previous actions
|
noise_vec[12+2*self.num_actions:12+3*self.num_actions] = 0. # previous actions
|
||||||
|
|
||||||
return noise_vec
|
return noise_vec
|
||||||
|
|
||||||
|
@ -418,6 +427,7 @@ class LeggedRobot(BaseTask):
|
||||||
self.dof_pos = self.dof_state.view(self.num_envs, self.num_dof, 2)[..., 0]
|
self.dof_pos = self.dof_state.view(self.num_envs, self.num_dof, 2)[..., 0]
|
||||||
self.dof_vel = self.dof_state.view(self.num_envs, self.num_dof, 2)[..., 1]
|
self.dof_vel = self.dof_state.view(self.num_envs, self.num_dof, 2)[..., 1]
|
||||||
self.base_quat = self.root_states[:, 3:7]
|
self.base_quat = self.root_states[:, 3:7]
|
||||||
|
self.rpy = get_euler_xyz(self.base_quat)
|
||||||
self.base_pos = self.root_states[:self.num_envs, 0:3]
|
self.base_pos = self.root_states[:self.num_envs, 0:3]
|
||||||
self.contact_forces = gymtorch.wrap_tensor(net_contact_forces).view(self.num_envs, -1, 3) # shape: num_envs, num_bodies, xyz axis
|
self.contact_forces = gymtorch.wrap_tensor(net_contact_forces).view(self.num_envs, -1, 3) # shape: num_envs, num_bodies, xyz axis
|
||||||
|
|
||||||
|
@ -628,7 +638,7 @@ class LeggedRobot(BaseTask):
|
||||||
|
|
||||||
def _reward_base_height(self):
|
def _reward_base_height(self):
|
||||||
# Penalize base height away from target
|
# Penalize base height away from target
|
||||||
base_height = torch.mean(self.root_states[:, 2].unsqueeze(1) - self.measured_heights, dim=1)
|
base_height = self.root_states[:, 2]
|
||||||
return torch.square(base_height - self.cfg.rewards.base_height_target)
|
return torch.square(base_height - self.cfg.rewards.base_height_target)
|
||||||
|
|
||||||
def _reward_torques(self):
|
def _reward_torques(self):
|
||||||
|
|
|
@ -19,9 +19,6 @@ class GO2RoughCfg( LeggedRobotCfg ):
|
||||||
'FR_calf_joint': -1.5, # [rad]
|
'FR_calf_joint': -1.5, # [rad]
|
||||||
'RR_calf_joint': -1.5, # [rad]
|
'RR_calf_joint': -1.5, # [rad]
|
||||||
}
|
}
|
||||||
class env(LeggedRobotCfg.env):
|
|
||||||
num_observations = 105
|
|
||||||
num_actions = 10
|
|
||||||
|
|
||||||
class control( LeggedRobotCfg.control ):
|
class control( LeggedRobotCfg.control ):
|
||||||
# PD Drive parameters:
|
# PD Drive parameters:
|
||||||
|
|
|
@ -24,12 +24,35 @@ class H1RoughCfg( LeggedRobotCfg ):
|
||||||
'right_shoulder_yaw_joint' : 0.,
|
'right_shoulder_yaw_joint' : 0.,
|
||||||
'right_elbow_joint' : 0.,
|
'right_elbow_joint' : 0.,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class env(LeggedRobotCfg.env):
|
||||||
|
num_observations = 42
|
||||||
|
num_actions = 10
|
||||||
|
test = False
|
||||||
|
|
||||||
|
|
||||||
class control( LeggedRobotCfg.control ):
|
class control( LeggedRobotCfg.control ):
|
||||||
# PD Drive parameters:
|
# PD Drive parameters:
|
||||||
control_type = 'P'
|
control_type = 'P'
|
||||||
stiffness = {'joint': 20.} # [N*m/rad]
|
# PD Drive parameters:
|
||||||
damping = {'joint': 0.5} # [N*m*s/rad]
|
stiffness = {'hip_yaw': 200,
|
||||||
|
'hip_roll': 200,
|
||||||
|
'hip_pitch': 200,
|
||||||
|
'knee': 300,
|
||||||
|
'ankle': 40,
|
||||||
|
'torso': 300,
|
||||||
|
'shoulder': 100,
|
||||||
|
"elbow":100,
|
||||||
|
} # [N*m/rad]
|
||||||
|
damping = { 'hip_yaw': 5,
|
||||||
|
'hip_roll': 5,
|
||||||
|
'hip_pitch': 5,
|
||||||
|
'knee': 6,
|
||||||
|
'ankle': 2,
|
||||||
|
'torso': 6,
|
||||||
|
'shoulder': 2,
|
||||||
|
"elbow":2,
|
||||||
|
} # [N*m/rad] # [N*m*s/rad]
|
||||||
# action scale: target angle = actionScale * action + defaultAngle
|
# action scale: target angle = actionScale * action + defaultAngle
|
||||||
action_scale = 0.25
|
action_scale = 0.25
|
||||||
# decimation: Number of control action updates @ sim DT per policy DT
|
# decimation: Number of control action updates @ sim DT per policy DT
|
||||||
|
@ -39,15 +62,26 @@ class H1RoughCfg( LeggedRobotCfg ):
|
||||||
file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/h1/urdf/h1.urdf'
|
file = '{LEGGED_GYM_ROOT_DIR}/resources/robots/h1/urdf/h1.urdf'
|
||||||
name = "h1"
|
name = "h1"
|
||||||
foot_name = "ankle"
|
foot_name = "ankle"
|
||||||
penalize_contacts_on = ["thigh", "calf"]
|
penalize_contacts_on = ["hip", "knee"]
|
||||||
terminate_after_contacts_on = ["base"]
|
terminate_after_contacts_on = ["pelvis"]
|
||||||
self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter
|
self_collisions = 1 # 1 to disable, 0 to enable...bitwise filter
|
||||||
|
flip_visual_attachments = False
|
||||||
|
|
||||||
class rewards( LeggedRobotCfg.rewards ):
|
class rewards( LeggedRobotCfg.rewards ):
|
||||||
soft_dof_pos_limit = 0.9
|
soft_dof_pos_limit = 0.9
|
||||||
base_height_target = 0.98
|
base_height_target = 0.98
|
||||||
class scales( LeggedRobotCfg.rewards.scales ):
|
class scales( LeggedRobotCfg.rewards.scales ):
|
||||||
torques = -0.0002
|
tracking_lin_vel = 1.0
|
||||||
|
tracking_ang_vel = 0.5
|
||||||
|
lin_vel_z = -2.0
|
||||||
|
ang_vel_xy = -1.0
|
||||||
|
orientation = -1.0
|
||||||
|
base_height = -100.0
|
||||||
|
dof_acc = -3.5e-8
|
||||||
|
feet_air_time = 1.0
|
||||||
|
collision = 0.0
|
||||||
|
action_rate = -0.01
|
||||||
|
torques = 0.0
|
||||||
dof_pos_limits = -10.0
|
dof_pos_limits = -10.0
|
||||||
|
|
||||||
class H1RoughCfgPPO( LeggedRobotCfgPPO ):
|
class H1RoughCfgPPO( LeggedRobotCfgPPO ):
|
||||||
|
@ -55,6 +89,6 @@ class H1RoughCfgPPO( LeggedRobotCfgPPO ):
|
||||||
entropy_coef = 0.01
|
entropy_coef = 0.01
|
||||||
class runner( LeggedRobotCfgPPO.runner ):
|
class runner( LeggedRobotCfgPPO.runner ):
|
||||||
run_name = ''
|
run_name = ''
|
||||||
experiment_name = 'rough_go2'
|
experiment_name = 'h1'
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ import torch
|
||||||
def play(args):
|
def play(args):
|
||||||
env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
|
env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
|
||||||
# override some parameters for testing
|
# override some parameters for testing
|
||||||
env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
|
env_cfg.env.num_envs = min(env_cfg.env.num_envs, 100)
|
||||||
env_cfg.terrain.num_rows = 5
|
env_cfg.terrain.num_rows = 5
|
||||||
env_cfg.terrain.num_cols = 5
|
env_cfg.terrain.num_cols = 5
|
||||||
env_cfg.terrain.curriculum = False
|
env_cfg.terrain.curriculum = False
|
||||||
|
@ -22,6 +22,8 @@ def play(args):
|
||||||
env_cfg.domain_rand.randomize_friction = False
|
env_cfg.domain_rand.randomize_friction = False
|
||||||
env_cfg.domain_rand.push_robots = False
|
env_cfg.domain_rand.push_robots = False
|
||||||
|
|
||||||
|
env_cfg.env.test = True
|
||||||
|
|
||||||
# prepare environment
|
# prepare environment
|
||||||
env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
|
env, _ = task_registry.make_env(name=args.task, args=args, env_cfg=env_cfg)
|
||||||
obs = env.get_observations()
|
obs = env.get_observations()
|
||||||
|
|
|
@ -680,7 +680,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="torso_joint"
|
name="torso_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="0 0 0"
|
xyz="0 0 0"
|
||||||
rpy="0 0 0" />
|
rpy="0 0 0" />
|
||||||
|
@ -738,7 +738,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="left_shoulder_pitch_joint"
|
name="left_shoulder_pitch_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="0.0055 0.15535 0.42999"
|
xyz="0.0055 0.15535 0.42999"
|
||||||
rpy="0.43633 0 0" />
|
rpy="0.43633 0 0" />
|
||||||
|
@ -796,7 +796,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="left_shoulder_roll_joint"
|
name="left_shoulder_roll_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="-0.0055 0.0565 -0.0165"
|
xyz="-0.0055 0.0565 -0.0165"
|
||||||
rpy="-0.43633 0 0" />
|
rpy="-0.43633 0 0" />
|
||||||
|
@ -854,7 +854,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="left_shoulder_yaw_joint"
|
name="left_shoulder_yaw_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="0 0 -0.1343"
|
xyz="0 0 -0.1343"
|
||||||
rpy="0 0 0" />
|
rpy="0 0 0" />
|
||||||
|
@ -912,7 +912,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="left_elbow_joint"
|
name="left_elbow_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="0.0185 0 -0.198"
|
xyz="0.0185 0 -0.198"
|
||||||
rpy="0 0 0" />
|
rpy="0 0 0" />
|
||||||
|
@ -970,7 +970,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="right_shoulder_pitch_joint"
|
name="right_shoulder_pitch_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="0.0055 -0.15535 0.42999"
|
xyz="0.0055 -0.15535 0.42999"
|
||||||
rpy="-0.43633 0 0" />
|
rpy="-0.43633 0 0" />
|
||||||
|
@ -1028,7 +1028,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="right_shoulder_roll_joint"
|
name="right_shoulder_roll_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="-0.0055 -0.0565 -0.0165"
|
xyz="-0.0055 -0.0565 -0.0165"
|
||||||
rpy="0.43633 0 0" />
|
rpy="0.43633 0 0" />
|
||||||
|
@ -1086,7 +1086,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="right_shoulder_yaw_joint"
|
name="right_shoulder_yaw_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="0 0 -0.1343"
|
xyz="0 0 -0.1343"
|
||||||
rpy="0 0 0" />
|
rpy="0 0 0" />
|
||||||
|
@ -1144,7 +1144,7 @@ Stephen Brawner (brawner@gmail.com)
|
||||||
</link>
|
</link>
|
||||||
<joint
|
<joint
|
||||||
name="right_elbow_joint"
|
name="right_elbow_joint"
|
||||||
type="revolute">
|
type="fixed">
|
||||||
<origin
|
<origin
|
||||||
xyz="0.0185 0 -0.198"
|
xyz="0.0185 0 -0.198"
|
||||||
rpy="0 0 0" />
|
rpy="0 0 0" />
|
||||||
|
|
Loading…
Reference in New Issue