# Go2Py_SIM/Go2Py/control/cat.py

import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class RunningMeanStd(nn.Module):
    """Track a running mean/variance of samples and standardize with them.

    The statistics are registered buffers (``running_mean``, ``running_var``,
    ``count``), so they are saved and restored with the module's state_dict.

    Args:
        shape: shape of one sample (the leading batch dimension excluded).
        epsilon: added to the variance before the square root in ``forward``.
    """

    def __init__(self, shape = (), epsilon=1e-08):
        super().__init__()
        for buffer_name, initial_value in (
            ("running_mean", torch.zeros(shape)),
            ("running_var", torch.ones(shape)),
            ("count", torch.ones(())),
        ):
            self.register_buffer(buffer_name, initial_value)
        self.epsilon = epsilon

    def forward(self, obs, update = True):
        """Return ``obs`` standardized by the running statistics.

        When ``update`` is true, the statistics are refreshed from ``obs``
        first, so the result uses the updated mean and variance.
        """
        if update:
            self.update(obs)
        scale = torch.sqrt(self.running_var + self.epsilon)
        return (obs - self.running_mean) / scale

    def update(self, x):
        """Fold a batch ``x`` (samples along dim 0) into the statistics.

        The batch variance is the population variance (``correction=0``).
        """
        self.update_from_moments(
            torch.mean(x, dim=0),
            torch.var(x, correction=0, dim=0),
            x.shape[0],
        )

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into the running statistics."""
        merged = update_mean_var_count_from_moments(
            self.running_mean,
            self.running_var,
            self.count,
            batch_mean,
            batch_var,
            batch_count,
        )
        self.running_mean, self.running_var, self.count = merged
def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Merge running (mean, variance, count) with a batch's moments.

    Uses the pairwise/parallel combination (Chan et al.): the combined sum of
    squared deviations is the sum of both parts plus a between-group term.
    Variances are population variances.

    Returns:
        ``(new_mean, new_var, new_count)``.
    """
    shift = batch_mean - mean
    total = count + batch_count
    merged_mean = mean + shift * batch_count / total
    # Sum of squared deviations of both groups plus the correction for
    # their means being ``shift`` apart.
    sum_sq = (
        var * count
        + batch_var * batch_count
        + torch.square(shift) * count * batch_count / total
    )
    return merged_mean, sum_sq / total, total
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize ``layer`` in place with orthogonal weights and a constant bias.

    Args:
        layer: module with a ``weight`` tensor and a ``bias`` attribute
            (e.g. ``nn.Linear``); ``bias`` may be ``None``.
        std: gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: value every bias entry is set to.

    Returns:
        The same ``layer`` object, to allow inline use inside ``nn.Sequential``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers built with ``bias=False`` have ``bias is None``; skip them
    # instead of failing inside ``constant_``.
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
class Agent(nn.Module):
    """Actor-critic network pair matching the CaT checkpoint layout.

    Both heads map a 45-dimensional input through ELU MLPs with hidden sizes
    512, 256 and 128. The critic ends in one output; the actor ends in 12.
    ``actor_logstd`` and the ``obs_rms`` / ``value_rms`` normalizers are
    registered as module state, so they are part of the state_dict.
    """

    def __init__(self):
        super().__init__()

        def build_mlp(out_features, out_std):
            # Linear/ELU pairs for the hidden layers, then a plain linear head
            # whose orthogonal gain is ``out_std``. Module indices (0..6) match
            # a hand-written nn.Sequential, so state_dict keys are unchanged.
            widths = (45, 512, 256, 128)
            modules = []
            for fan_in, fan_out in zip(widths, widths[1:]):
                modules.append(layer_init(nn.Linear(fan_in, fan_out)))
                modules.append(nn.ELU())
            modules.append(layer_init(nn.Linear(widths[-1], out_features), std=out_std))
            return nn.Sequential(*modules)

        self.critic = build_mlp(1, 1.0)
        self.actor_mean = build_mlp(12, 0.01)
        self.actor_logstd = nn.Parameter(torch.zeros(1, 12))
        self.obs_rms = RunningMeanStd(shape = (45,))
        self.value_rms = RunningMeanStd(shape = ())

    def get_value(self, x):
        """Return the critic output for input ``x``."""
        value = self.critic(x)
        return value

    def get_action(self, x):
        """Return the actor's action mean for input ``x`` (no sampling)."""
        mean = self.actor_mean(x)
        return mean
class Policy:
    """Load an ``Agent`` checkpoint and evaluate its deterministic action.

    Args:
        checkpoint_path: path to a state_dict saved for ``Agent``.
    """

    def __init__(self, checkpoint_path):
        self.agent = Agent()
        # The agent is never moved off the CPU here, so map the checkpoint's
        # tensors to CPU as well. Otherwise a GPU-saved checkpoint cannot be
        # loaded on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (weights_only=True is an option if the file is a plain
        # tensor state_dict).
        actor_sd = torch.load(checkpoint_path, map_location="cpu")
        self.agent.load_state_dict(actor_sd)

    def __call__(self, obs, info):
        """Return the actor mean for one observation.

        ``obs`` gets a leading batch dimension and is normalized with the
        checkpoint's running statistics without updating them. ``info`` is
        unused. Assumes ``obs`` is a single unbatched observation tensor, in
        which case the result is ``(1, 12)``.
        """
        with torch.no_grad():
            normalized = self.agent.obs_rms(obs.unsqueeze(0), update = False)
            action = self.agent.get_action(normalized)
        return action
class CommandInterface:
    """Hold a velocity command that callers set through its attributes.

    ``x_vel_cmd``, ``y_vel_cmd`` and ``yaw_vel_cmd`` start at 0.0.
    ``limits`` is stored but not used by this class.
    """

    def __init__(self, limits=None):
        self.limits = limits
        self.x_vel_cmd = 0.0
        self.y_vel_cmd = 0.0
        self.yaw_vel_cmd = 0.0

    def get_command(self):
        """Return ``(command, False)``.

        ``command`` is a new float64 array ``[x_vel, y_vel, yaw_vel]``.
        The flag is always False.
        """
        current = (self.x_vel_cmd, self.y_vel_cmd, self.yaw_vel_cmd)
        command = np.zeros((3,))
        for slot, value in enumerate(current):
            command[slot] = value
        return command, False
class CaTAgent:
    """Run a CaT locomotion policy against a Go2 robot interface.

    The agent builds the 45-dimensional observation the policy consumes,
    turns policy actions into PD joint-position targets, and sends them to
    ``robot``. Every joint-indexed array stored on the agent uses the *policy*
    joint order (FL, FR, RL, RR legs). The robot interface is addressed in the
    *Unitree* order (FR, FL, RR, RL legs); conversion happens at the boundary.

    Args:
        command_profile: object whose ``get_command()`` returns a
            ``(command, flag)`` pair. ``command`` has 3 entries: linear x/y
            velocity followed by yaw rate. The flag is ignored.
        robot: object exposing ``getJointStates()`` (``None`` or a mapping
            with ``'q'`` and ``'dq'``, expected in Unitree order),
            ``getGravityInBody()``, ``getIMU()`` (mapping with ``'gyro'``)
            and ``setCommands(q, dq, kp, kd, tau_ff)``.
    """

    # Joint order used for the policy's observations and actions.
    _POLICY_JOINT_NAMES = (
        "FL_hip_joint",
        "FL_thigh_joint",
        "FL_calf_joint",
        "FR_hip_joint",
        "FR_thigh_joint",
        "FR_calf_joint",
        "RL_hip_joint",
        "RL_thigh_joint",
        "RL_calf_joint",
        "RR_hip_joint",
        "RR_thigh_joint",
        "RR_calf_joint",
    )
    # Joint order used by the Unitree robot interface.
    _UNITREE_JOINT_NAMES = (
        "FR_hip_joint",
        "FR_thigh_joint",
        "FR_calf_joint",
        "FL_hip_joint",
        "FL_thigh_joint",
        "FL_calf_joint",
        "RR_hip_joint",
        "RR_thigh_joint",
        "RR_calf_joint",
        "RL_hip_joint",
        "RL_thigh_joint",
        "RL_calf_joint",
    )

    def __init__(self, command_profile, robot):
        self.robot = robot
        self.command_profile = command_profile
        self.sim_dt = 0.001
        self.decimation = 20
        self.dt = self.sim_dt * self.decimation  # policy period: 0.02 s
        self.timestep = 0
        self.device = "cpu"

        policy_names = self._POLICY_JOINT_NAMES
        unitree_names = self._UNITREE_JOINT_NAMES
        # Each row is (index in the source order, index of the same joint in
        # the other order). Hence:
        #   x_unitree[self.policy_to_unitree_map[:, 1]] -> policy order
        #   x_policy[self.unitree_to_policy_map[:, 1]]  -> Unitree order
        self.policy_to_unitree_map = np.array(
            [(i, unitree_names.index(name)) for i, name in enumerate(policy_names)]
        ).astype(np.uint32)
        self.unitree_to_policy_map = np.array(
            [(i, policy_names.index(name)) for i, name in enumerate(unitree_names)]
        ).astype(np.uint32)

        # Nominal standing pose (rad); actions are offsets from it.
        default_joint_angles = {
            "FL_hip_joint": 0.1,
            "RL_hip_joint": 0.1,
            "FR_hip_joint": -0.1,
            "RR_hip_joint": -0.1,
            "FL_thigh_joint": 0.8,
            "RL_thigh_joint": 1.0,
            "FR_thigh_joint": 0.8,
            "RR_thigh_joint": 1.0,
            "FL_calf_joint": -1.5,
            "RL_calf_joint": -1.5,
            "FR_calf_joint": -1.5,
            "RR_calf_joint": -1.5,
        }
        self.default_dof_pos = np.array(
            [default_joint_angles[name] for name in policy_names]
        )
        self.p_gains = np.full(12, 20.0)
        self.d_gains = np.full(12, 0.5)
        print(f"p_gains: {self.p_gains}")

        # Cached state (policy order) reused when the robot reports no state.
        self.commands = np.zeros(3)
        self.actions = torch.zeros((1, 12))
        self.last_actions = torch.zeros(12)
        self.gravity_vector = np.zeros(3)
        self.dof_pos = np.zeros(12)
        self.dof_vel = np.zeros(12)
        self.body_linear_vel = np.zeros(3)
        self.body_angular_vel = np.zeros(3)
        self.joint_pos_target = np.zeros(12)
        self.joint_vel_target = np.zeros(12)
        self.torques = np.zeros(12)
        self.contact_state = np.ones(4)
        self.test = 0

    def wait_for_state(self):
        """Placeholder for a blocking state read; currently returns None."""
        pass

    def get_obs(self):
        """Refresh the cached robot state and return the policy observation.

        The observation is a float32 tensor on ``self.device`` with layout:
        ``[0:3]`` gyro * 0.25, ``[3:6]`` command * (2, 2, 0.25), ``[6:9]``
        gravity in the body frame, ``[9:21]`` joint positions, ``[21:33]``
        joint velocities * 0.05, ``[33:45]`` the last action sent.
        If ``robot.getJointStates()`` returns None, the cached values are
        reused.
        """
        cmds, _reset_timer = self.command_profile.get_command()
        self.commands[:] = cmds
        joint_state = self.robot.getJointStates()
        if joint_state is not None:
            self.gravity_vector = self.robot.getGravityInBody()
            # Robot reports Unitree order; reorder into policy order.
            to_policy = self.policy_to_unitree_map[:, 1]
            self.dof_pos = np.array(joint_state["q"])[to_policy]
            self.dof_vel = np.array(joint_state["dq"])[to_policy]
            self.body_angular_vel = np.asarray(self.robot.getIMU()["gyro"])

        # The robot may return gravity as a column (3, 1); the initial cache
        # is a flat (3,) array. Accept both instead of failing on reset().
        gravity = np.asarray(self.gravity_vector)
        if gravity.ndim > 1:
            gravity = gravity[:, 0]

        ob = np.concatenate(
            (
                self.body_angular_vel * 0.25,
                self.commands * np.array([2.0, 2.0, 0.25]),
                gravity,
                self.dof_pos * 1.0,
                self.dof_vel * 0.05,
                self.actions[0].detach().cpu().numpy(),
            ),
            axis=0,
        )
        return torch.tensor(ob, device=self.device).float()

    def publish_action(self, action, hard_reset=False):
        """Convert a policy action into PD targets and send them to the robot.

        Target joint positions are ``default_dof_pos + 0.25 * action`` with
        zero target velocity. ``self.torques`` records the PD torque that the
        current cached state would produce (policy order).

        Raises:
            NotImplementedError: if ``hard_reset`` is requested; the robot
                command interface used here has no hard-reset command.
        """
        if hard_reset:
            raise NotImplementedError(
                "hard_reset is not supported by the robot command interface"
            )
        self.joint_pos_target = (
            action[0, :12].detach().cpu().numpy() * 0.25
        ).flatten()
        self.joint_pos_target += self.default_dof_pos
        self.joint_vel_target = np.zeros(12)
        self.torques = (self.joint_pos_target - self.dof_pos) * self.p_gains + (
            self.joint_vel_target - self.dof_vel
        ) * self.d_gains
        # Targets are held in policy order; the robot expects Unitree order.
        to_unitree = self.unitree_to_policy_map[:, 1]
        self.robot.setCommands(
            self.joint_pos_target[to_unitree],
            self.joint_vel_target[to_unitree],
            self.p_gains[to_unitree],
            self.d_gains[to_unitree],
            np.zeros(12),
        )

    def reset(self):
        """Clear the last action and step counter and return a fresh observation."""
        self.actions = torch.zeros((1, 12))
        self.time = time.time()
        self.timestep = 0
        return self.get_obs()

    def step(self, actions, hard_reset=False):
        """Send ``actions`` to the robot and return ``(obs, None, None, infos)``.

        ``infos`` holds the cached state, targets and commands, each with a
        leading axis of size 1 (``torques`` is stored as-is).
        This method does not sleep to hold the loop at ``self.dt``.
        """
        self.last_actions = self.actions[:]
        self.actions = actions
        self.publish_action(self.actions, hard_reset=hard_reset)
        self.time = time.time()
        obs = self.get_obs()
        infos = {
            "joint_pos": self.dof_pos[np.newaxis, :],
            "joint_vel": self.dof_vel[np.newaxis, :],
            "joint_pos_target": self.joint_pos_target[np.newaxis, :],
            "joint_vel_target": self.joint_vel_target[np.newaxis, :],
            "body_linear_vel": self.body_linear_vel[np.newaxis, :],
            "body_angular_vel": self.body_angular_vel[np.newaxis, :],
            "contact_state": self.contact_state[np.newaxis, :],
            "body_linear_vel_cmd": self.commands[np.newaxis, 0:2],
            "body_angular_vel_cmd": self.commands[np.newaxis, 2:],
            "torques": self.torques,
        }
        self.timestep += 1
        return obs, None, None, infos