Adding some logging and correcting first obs.

This commit is contained in:
jogima-cyber 2024-10-09 20:08:54 +00:00
parent 94cf0da3ba
commit 3c2b19f955
5 changed files with 533 additions and 88 deletions

View File

@ -230,8 +230,8 @@ class CaTAgent:
print(f"p_gains: {self.p_gains}")
self.commands = np.zeros(3)
self.actions = torch.zeros((1, 12))
self.last_actions = torch.zeros(12)
self.actions = np.zeros((1, 12))
self.last_actions = np.zeros((1,12))
self.gravity_vector = np.zeros(3)
self.dof_pos = np.zeros(12)
self.dof_vel = np.zeros(12)
@ -239,9 +239,11 @@ class CaTAgent:
self.body_angular_vel = np.zeros(3)
self.joint_pos_target = np.zeros(12)
self.joint_vel_target = np.zeros(12)
self.prev_joint_acc = None
self.torques = np.zeros(12)
self.contact_state = np.ones(4)
self.foot_contact_forces_mag = np.zeros(4)
self.prev_foot_contact_forces_mag = np.zeros(4)
self.test = 0
def wait_for_state(self):
@ -256,10 +258,16 @@ class CaTAgent:
joint_state = self.robot.getJointStates()
if joint_state is not None:
self.gravity_vector = self.robot.getGravityInBody()
self.prev_dof_pos = self.dof_pos.copy()
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
self.prev_dof_vel = self.dof_vel.copy()
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
self.body_angular_vel = self.robot.getIMU()["gyro"]
self.foot_contact_forces_mag = self.robot.getFootContact()
self.body_linear_vel = self.robot.getLinVel()
try:
self.foot_contact_forces_mag = self.robot.getFootContact()
except:
pass
ob = np.concatenate(
(
@ -268,7 +276,7 @@ class CaTAgent:
self.gravity_vector[:, 0],
self.dof_pos * 1.0,
self.dof_vel * 0.05,
self.actions[0]
self.last_actions[0]
),
axis=0,
)
@ -320,6 +328,15 @@ class CaTAgent:
self.time = time.time()
obs = self.get_obs()
joint_acc = np.abs(self.prev_dof_vel - self.dof_vel) / self.dt
if self.prev_joint_acc is None:
self.prev_joint_acc = np.zeros_like(joint_acc)
joint_jerk = np.abs(self.prev_joint_acc - joint_acc) / self.dt
self.prev_joint_acc = joint_acc.copy()
foot_contact_rate = np.abs(self.foot_contact_forces_mag - self.prev_foot_contact_forces_mag)
self.prev_foot_contact_forces_mag = self.foot_contact_forces_mag.copy()
infos = {
"joint_pos": self.dof_pos[np.newaxis, :],
"joint_vel": self.dof_vel[np.newaxis, :],
@ -331,7 +348,10 @@ class CaTAgent:
"body_linear_vel_cmd": self.commands[np.newaxis, 0:2],
"body_angular_vel_cmd": self.commands[np.newaxis, 2:],
"torques": self.torques,
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy()
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy(),
"joint_acc": joint_acc[np.newaxis, :],
"joint_jerk": joint_jerk[np.newaxis, :],
"foot_contact_rate": foot_contact_rate[np.newaxis, :],
}
self.timestep += 1

View File

@ -321,6 +321,8 @@ class WalkTheseWaysAgent:
self.joint_vel_target = np.zeros(12)
self.torques = np.zeros(12)
self.contact_state = np.ones(4)
self.foot_contact_forces_mag = np.zeros(4)
self.prev_foot_contact_forces_mag = np.zeros(4)
self.test = 0
self.gait_indices = torch.zeros(self.num_envs, dtype=torch.float)
@ -345,6 +347,10 @@ class WalkTheseWaysAgent:
self.gravity_vector = self.robot.getGravityInBody()
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
try:
self.foot_contact_forces_mag = self.robot.getFootContact()
except:
pass
if reset_timer:
self.reset_gait_indices()
@ -464,6 +470,9 @@ class WalkTheseWaysAgent:
self.clock_inputs[:, 2] = torch.sin(2 * np.pi * self.foot_indices[2])
self.clock_inputs[:, 3] = torch.sin(2 * np.pi * self.foot_indices[3])
foot_contact_rate = np.abs(self.foot_contact_forces_mag - self.prev_foot_contact_forces_mag)
self.prev_foot_contact_forces_mag = self.foot_contact_forces_mag.copy()
infos = {
"joint_pos": self.dof_pos[np.newaxis, :],
"joint_vel": self.dof_vel[np.newaxis, :],
@ -476,6 +485,8 @@ class WalkTheseWaysAgent:
"body_linear_vel_cmd": self.commands[:, 0:2],
"body_angular_vel_cmd": self.commands[:, 2:],
"privileged_obs": None,
"foot_contact_rate": foot_contact_rate[np.newaxis, :],
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy(),
}
self.timestep += 1

View File

@ -95,7 +95,7 @@ class GO2Real():
gyro = self.state.gyro
quat = self.state.quat
temp = self.state.imu_temp
return {'accel': accel, 'gyro': gyro, 'quat': quat, 'temp': temp}
return {'accel': np.array(accel), 'gyro': np.array(gyro), 'quat': np.array(quat), 'temp': temp}
def getFootContacts(self):
"""Returns the raw foot contact forces"""

View File

@ -414,6 +414,12 @@ class Go2Sim:
if self.render and (self.step_counter % self.render_ds_ratio) == 0:
self.viewer.sync()
def getLinVel(self):
_, q = self.getPose()
world_R_body = Rotation.from_quat([q[1], q[2], q[3], q[0]]).as_matrix()
body_v = world_R_body.T@self.data.qvel[0:3].reshape(3,1)
return body_v
def stepHighlevel(self, vx, vy, omega_z, body_z_offset=0, step_height = 0.08, kp=[2, 0.5, 0.5], ki=[0.02, 0.01, 0.01]):
policy_info = {}
if self.step_counter % (self.control_dt // self.dt) == 0:

File diff suppressed because one or more lines are too long