Adding some logging and correcting first obs.
This commit is contained in:
parent
94cf0da3ba
commit
3c2b19f955
|
@ -230,8 +230,8 @@ class CaTAgent:
|
|||
print(f"p_gains: {self.p_gains}")
|
||||
|
||||
self.commands = np.zeros(3)
|
||||
self.actions = torch.zeros((1, 12))
|
||||
self.last_actions = torch.zeros(12)
|
||||
self.actions = np.zeros((1, 12))
|
||||
self.last_actions = np.zeros((1,12))
|
||||
self.gravity_vector = np.zeros(3)
|
||||
self.dof_pos = np.zeros(12)
|
||||
self.dof_vel = np.zeros(12)
|
||||
|
@ -239,9 +239,11 @@ class CaTAgent:
|
|||
self.body_angular_vel = np.zeros(3)
|
||||
self.joint_pos_target = np.zeros(12)
|
||||
self.joint_vel_target = np.zeros(12)
|
||||
self.prev_joint_acc = None
|
||||
self.torques = np.zeros(12)
|
||||
self.contact_state = np.ones(4)
|
||||
self.foot_contact_forces_mag = np.zeros(4)
|
||||
self.prev_foot_contact_forces_mag = np.zeros(4)
|
||||
self.test = 0
|
||||
|
||||
def wait_for_state(self):
|
||||
|
@ -256,10 +258,16 @@ class CaTAgent:
|
|||
joint_state = self.robot.getJointStates()
|
||||
if joint_state is not None:
|
||||
self.gravity_vector = self.robot.getGravityInBody()
|
||||
self.prev_dof_pos = self.dof_pos.copy()
|
||||
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
|
||||
self.prev_dof_vel = self.dof_vel.copy()
|
||||
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
|
||||
self.body_angular_vel = self.robot.getIMU()["gyro"]
|
||||
self.foot_contact_forces_mag = self.robot.getFootContact()
|
||||
self.body_linear_vel = self.robot.getLinVel()
|
||||
try:
|
||||
self.foot_contact_forces_mag = self.robot.getFootContact()
|
||||
except:
|
||||
pass
|
||||
|
||||
ob = np.concatenate(
|
||||
(
|
||||
|
@ -268,7 +276,7 @@ class CaTAgent:
|
|||
self.gravity_vector[:, 0],
|
||||
self.dof_pos * 1.0,
|
||||
self.dof_vel * 0.05,
|
||||
self.actions[0]
|
||||
self.last_actions[0]
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
|
@ -320,6 +328,15 @@ class CaTAgent:
|
|||
self.time = time.time()
|
||||
obs = self.get_obs()
|
||||
|
||||
joint_acc = np.abs(self.prev_dof_vel - self.dof_vel) / self.dt
|
||||
if self.prev_joint_acc is None:
|
||||
self.prev_joint_acc = np.zeros_like(joint_acc)
|
||||
joint_jerk = np.abs(self.prev_joint_acc - joint_acc) / self.dt
|
||||
self.prev_joint_acc = joint_acc.copy()
|
||||
|
||||
foot_contact_rate = np.abs(self.foot_contact_forces_mag - self.prev_foot_contact_forces_mag)
|
||||
self.prev_foot_contact_forces_mag = self.foot_contact_forces_mag.copy()
|
||||
|
||||
infos = {
|
||||
"joint_pos": self.dof_pos[np.newaxis, :],
|
||||
"joint_vel": self.dof_vel[np.newaxis, :],
|
||||
|
@ -331,7 +348,10 @@ class CaTAgent:
|
|||
"body_linear_vel_cmd": self.commands[np.newaxis, 0:2],
|
||||
"body_angular_vel_cmd": self.commands[np.newaxis, 2:],
|
||||
"torques": self.torques,
|
||||
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy()
|
||||
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy(),
|
||||
"joint_acc": joint_acc[np.newaxis, :],
|
||||
"joint_jerk": joint_jerk[np.newaxis, :],
|
||||
"foot_contact_rate": foot_contact_rate[np.newaxis, :],
|
||||
}
|
||||
|
||||
self.timestep += 1
|
||||
|
|
|
@ -321,6 +321,8 @@ class WalkTheseWaysAgent:
|
|||
self.joint_vel_target = np.zeros(12)
|
||||
self.torques = np.zeros(12)
|
||||
self.contact_state = np.ones(4)
|
||||
self.foot_contact_forces_mag = np.zeros(4)
|
||||
self.prev_foot_contact_forces_mag = np.zeros(4)
|
||||
self.test = 0
|
||||
|
||||
self.gait_indices = torch.zeros(self.num_envs, dtype=torch.float)
|
||||
|
@ -345,6 +347,10 @@ class WalkTheseWaysAgent:
|
|||
self.gravity_vector = self.robot.getGravityInBody()
|
||||
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
|
||||
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
|
||||
try:
|
||||
self.foot_contact_forces_mag = self.robot.getFootContact()
|
||||
except:
|
||||
pass
|
||||
|
||||
if reset_timer:
|
||||
self.reset_gait_indices()
|
||||
|
@ -464,6 +470,9 @@ class WalkTheseWaysAgent:
|
|||
self.clock_inputs[:, 2] = torch.sin(2 * np.pi * self.foot_indices[2])
|
||||
self.clock_inputs[:, 3] = torch.sin(2 * np.pi * self.foot_indices[3])
|
||||
|
||||
foot_contact_rate = np.abs(self.foot_contact_forces_mag - self.prev_foot_contact_forces_mag)
|
||||
self.prev_foot_contact_forces_mag = self.foot_contact_forces_mag.copy()
|
||||
|
||||
infos = {
|
||||
"joint_pos": self.dof_pos[np.newaxis, :],
|
||||
"joint_vel": self.dof_vel[np.newaxis, :],
|
||||
|
@ -476,6 +485,8 @@ class WalkTheseWaysAgent:
|
|||
"body_linear_vel_cmd": self.commands[:, 0:2],
|
||||
"body_angular_vel_cmd": self.commands[:, 2:],
|
||||
"privileged_obs": None,
|
||||
"foot_contact_rate": foot_contact_rate[np.newaxis, :],
|
||||
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy(),
|
||||
}
|
||||
|
||||
self.timestep += 1
|
||||
|
|
|
@ -95,7 +95,7 @@ class GO2Real():
|
|||
gyro = self.state.gyro
|
||||
quat = self.state.quat
|
||||
temp = self.state.imu_temp
|
||||
return {'accel': accel, 'gyro': gyro, 'quat': quat, 'temp': temp}
|
||||
return {'accel': np.array(accel), 'gyro': np.array(gyro), 'quat': np.array(quat), 'temp': temp}
|
||||
|
||||
def getFootContacts(self):
|
||||
"""Returns the raw foot contact forces"""
|
||||
|
|
|
@ -414,6 +414,12 @@ class Go2Sim:
|
|||
if self.render and (self.step_counter % self.render_ds_ratio) == 0:
|
||||
self.viewer.sync()
|
||||
|
||||
def getLinVel(self):
|
||||
_, q = self.getPose()
|
||||
world_R_body = Rotation.from_quat([q[1], q[2], q[3], q[0]]).as_matrix()
|
||||
body_v = world_R_body.T@self.data.qvel[0:3].reshape(3,1)
|
||||
return body_v
|
||||
|
||||
def stepHighlevel(self, vx, vy, omega_z, body_z_offset=0, step_height = 0.08, kp=[2, 0.5, 0.5], ki=[0.02, 0.01, 0.01]):
|
||||
policy_info = {}
|
||||
if self.step_counter % (self.control_dt // self.dt) == 0:
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue