walk these ways controller added and tested in simulation.
This commit is contained in:
parent
174c5a3460
commit
cf73faf024
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,477 @@
|
|||
import os
|
||||
import pickle as pkl
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
def loadParameters(path):
|
||||
with open(path + "/parameters.pkl", "rb") as file:
|
||||
pkl_cfg = pkl.load(file)
|
||||
cfg = pkl_cfg["Cfg"]
|
||||
return cfg
|
||||
|
||||
|
||||
def class_to_dict(obj) -> dict:
|
||||
if not hasattr(obj, "__dict__"):
|
||||
return obj
|
||||
result = {}
|
||||
for key in dir(obj):
|
||||
if key.startswith("_") or key == "terrain":
|
||||
continue
|
||||
element = []
|
||||
val = getattr(obj, key)
|
||||
if isinstance(val, list):
|
||||
for item in val:
|
||||
element.append(class_to_dict(item))
|
||||
else:
|
||||
element = class_to_dict(val)
|
||||
result[key] = element
|
||||
return result
|
||||
|
||||
|
||||
class HistoryWrapper:
|
||||
def __init__(self, env):
|
||||
self.env = env
|
||||
|
||||
if isinstance(self.env.cfg, dict):
|
||||
self.obs_history_length = self.env.cfg["env"]["num_observation_history"]
|
||||
else:
|
||||
self.obs_history_length = self.env.cfg.env.num_observation_history
|
||||
self.num_obs_history = self.obs_history_length * self.env.num_obs
|
||||
self.obs_history = torch.zeros(
|
||||
self.env.num_envs,
|
||||
self.num_obs_history,
|
||||
dtype=torch.float,
|
||||
device=self.env.device,
|
||||
requires_grad=False,
|
||||
)
|
||||
self.num_privileged_obs = self.env.num_privileged_obs
|
||||
|
||||
def step(self, action):
|
||||
obs, rew, done, info = self.env.step(action)
|
||||
privileged_obs = info["privileged_obs"]
|
||||
|
||||
self.obs_history = torch.cat(
|
||||
(self.obs_history[:, self.env.num_obs :], obs), dim=-1
|
||||
)
|
||||
return (
|
||||
{
|
||||
"obs": obs,
|
||||
"privileged_obs": privileged_obs,
|
||||
"obs_history": self.obs_history,
|
||||
},
|
||||
rew,
|
||||
done,
|
||||
info,
|
||||
)
|
||||
|
||||
def get_observations(self):
|
||||
obs = self.env.get_observations()
|
||||
privileged_obs = self.env.get_privileged_observations()
|
||||
self.obs_history = torch.cat(
|
||||
(self.obs_history[:, self.env.num_obs :], obs), dim=-1
|
||||
)
|
||||
return {
|
||||
"obs": obs,
|
||||
"privileged_obs": privileged_obs,
|
||||
"obs_history": self.obs_history,
|
||||
}
|
||||
|
||||
def get_obs(self):
|
||||
obs = self.env.get_obs()
|
||||
privileged_obs = self.env.get_privileged_observations()
|
||||
self.obs_history = torch.cat(
|
||||
(self.obs_history[:, self.env.num_obs :], obs), dim=-1
|
||||
)
|
||||
return {
|
||||
"obs": obs,
|
||||
"privileged_obs": privileged_obs,
|
||||
"obs_history": self.obs_history,
|
||||
}
|
||||
|
||||
def reset_idx(
|
||||
self, env_ids
|
||||
): # it might be a problem that this isn't getting called!!
|
||||
ret = self.env.reset_idx(env_ids)
|
||||
self.obs_history[env_ids, :] = 0
|
||||
return ret
|
||||
|
||||
def reset(self):
|
||||
ret = self.env.reset()
|
||||
privileged_obs = self.env.get_privileged_observations()
|
||||
self.obs_history[:, :] = 0
|
||||
return {
|
||||
"obs": ret,
|
||||
"privileged_obs": privileged_obs,
|
||||
"obs_history": self.obs_history,
|
||||
}
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.env, name)
|
||||
|
||||
|
||||
class CommandInterface:
|
||||
def __init__(self, limits=None):
|
||||
self.limits = limits
|
||||
gaits = {
|
||||
"pronking": [0, 0, 0],
|
||||
"trotting": [0.5, 0, 0],
|
||||
"bounding": [0, 0.5, 0],
|
||||
"pacing": [0, 0, 0.5],
|
||||
}
|
||||
self.x_vel_cmd, self.y_vel_cmd, self.yaw_vel_cmd = 0.0, 0.0, 0.0
|
||||
self.body_height_cmd = 0.0
|
||||
self.step_frequency_cmd = 3.0
|
||||
self.gait = torch.tensor(gaits["trotting"])
|
||||
self.footswing_height_cmd = 0.03
|
||||
self.pitch_cmd = 0.0
|
||||
self.roll_cmd = 0.0
|
||||
self.stance_width_cmd = 0.0
|
||||
|
||||
def get_command(self):
|
||||
command = np.zeros((19,))
|
||||
command[0] = self.x_vel_cmd
|
||||
command[1] = self.y_vel_cmd
|
||||
command[2] = self.yaw_vel_cmd
|
||||
command[3] = self.body_height_cmd
|
||||
command[4] = self.step_frequency_cmd
|
||||
command[5:8] = self.gait
|
||||
command[8] = 0.5
|
||||
command[9] = self.footswing_height_cmd
|
||||
command[10] = self.pitch_cmd
|
||||
command[11] = self.roll_cmd
|
||||
command[12] = self.stance_width_cmd
|
||||
return command, False
|
||||
|
||||
|
||||
class Policy:
|
||||
def __init__(self, checkpoint_path):
|
||||
self.body = torch.jit.load(
|
||||
os.path.join(checkpoint_path, "checkpoints/body_latest.jit")
|
||||
)
|
||||
self.adaptation_module = torch.jit.load(
|
||||
os.path.join(checkpoint_path, "checkpoints/adaptation_module_latest.jit")
|
||||
)
|
||||
|
||||
def __call__(self, obs, info):
|
||||
latent = self.adaptation_module.forward(obs["obs_history"].to("cpu"))
|
||||
action = self.body.forward(
|
||||
torch.cat((obs["obs_history"].to("cpu"), latent), dim=-1)
|
||||
)
|
||||
info["latent"] = latent
|
||||
return action
|
||||
|
||||
|
||||
class WalkTheseWaysAgent:
|
||||
def __init__(self, cfg, command_profile, robot):
|
||||
self.robot = robot
|
||||
if not isinstance(cfg, dict):
|
||||
cfg = class_to_dict(cfg)
|
||||
self.cfg = cfg
|
||||
self.command_profile = command_profile
|
||||
# self.lcm_bridge = LCMBridgeClient(robot_name=self.robot_name)
|
||||
self.dt = self.cfg["control"]["decimation"] * self.cfg["sim"]["dt"]
|
||||
self.timestep = 0
|
||||
|
||||
self.num_obs = self.cfg["env"]["num_observations"]
|
||||
self.num_envs = 1
|
||||
self.num_privileged_obs = self.cfg["env"]["num_privileged_obs"]
|
||||
self.num_actions = self.cfg["env"]["num_actions"]
|
||||
self.num_commands = self.cfg["commands"]["num_commands"]
|
||||
self.device = "cpu"
|
||||
|
||||
if "obs_scales" in self.cfg.keys():
|
||||
self.obs_scales = self.cfg["obs_scales"]
|
||||
else:
|
||||
self.obs_scales = self.cfg["normalization"]["obs_scales"]
|
||||
|
||||
self.commands_scale = np.array(
|
||||
[
|
||||
self.obs_scales["lin_vel"],
|
||||
self.obs_scales["lin_vel"],
|
||||
self.obs_scales["ang_vel"],
|
||||
self.obs_scales["body_height_cmd"],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
self.obs_scales["footswing_height_cmd"],
|
||||
self.obs_scales["body_pitch_cmd"],
|
||||
# 0, self.obs_scales["body_pitch_cmd"],
|
||||
self.obs_scales["body_roll_cmd"],
|
||||
self.obs_scales["stance_width_cmd"],
|
||||
self.obs_scales["stance_length_cmd"],
|
||||
self.obs_scales["aux_reward_cmd"],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
]
|
||||
)[: self.num_commands]
|
||||
|
||||
joint_names = [
|
||||
"FL_hip_joint",
|
||||
"FL_thigh_joint",
|
||||
"FL_calf_joint",
|
||||
"FR_hip_joint",
|
||||
"FR_thigh_joint",
|
||||
"FR_calf_joint",
|
||||
"RL_hip_joint",
|
||||
"RL_thigh_joint",
|
||||
"RL_calf_joint",
|
||||
"RR_hip_joint",
|
||||
"RR_thigh_joint",
|
||||
"RR_calf_joint",
|
||||
]
|
||||
|
||||
policy_joint_names = joint_names
|
||||
|
||||
unitree_joint_names = [
|
||||
"FR_hip_joint",
|
||||
"FR_thigh_joint",
|
||||
"FR_calf_joint",
|
||||
"FL_hip_joint",
|
||||
"FL_thigh_joint",
|
||||
"FL_calf_joint",
|
||||
"RR_hip_joint",
|
||||
"RR_thigh_joint",
|
||||
"RR_calf_joint",
|
||||
"RL_hip_joint",
|
||||
"RL_thigh_joint",
|
||||
"RL_calf_joint",
|
||||
]
|
||||
|
||||
policy_to_unitree_map=[]
|
||||
unitree_to_policy_map=[]
|
||||
for i,policy_joint_name in enumerate(policy_joint_names):
|
||||
id = np.where([name==policy_joint_name for name in unitree_joint_names])[0][0]
|
||||
policy_to_unitree_map.append((i,id))
|
||||
self.policy_to_unitree_map=np.array(policy_to_unitree_map)
|
||||
|
||||
for i,unitree_joint_name in enumerate(unitree_joint_names):
|
||||
id = np.where([name==unitree_joint_name for name in policy_joint_names])[0][0]
|
||||
unitree_to_policy_map.append((i,id))
|
||||
self.unitree_to_policy_map=np.array(unitree_to_policy_map)
|
||||
|
||||
self.default_dof_pos = np.array(
|
||||
[
|
||||
self.cfg["init_state"]["default_joint_angles"][name]
|
||||
for name in joint_names
|
||||
]
|
||||
)
|
||||
try:
|
||||
self.default_dof_pos_scale = np.array(
|
||||
[
|
||||
self.cfg["init_state"]["default_hip_scales"],
|
||||
self.cfg["init_state"]["default_thigh_scales"],
|
||||
self.cfg["init_state"]["default_calf_scales"],
|
||||
self.cfg["init_state"]["default_hip_scales"],
|
||||
self.cfg["init_state"]["default_thigh_scales"],
|
||||
self.cfg["init_state"]["default_calf_scales"],
|
||||
self.cfg["init_state"]["default_hip_scales"],
|
||||
self.cfg["init_state"]["default_thigh_scales"],
|
||||
self.cfg["init_state"]["default_calf_scales"],
|
||||
self.cfg["init_state"]["default_hip_scales"],
|
||||
self.cfg["init_state"]["default_thigh_scales"],
|
||||
self.cfg["init_state"]["default_calf_scales"],
|
||||
]
|
||||
)
|
||||
except KeyError:
|
||||
self.default_dof_pos_scale = np.ones(12)
|
||||
self.default_dof_pos = self.default_dof_pos * self.default_dof_pos_scale
|
||||
|
||||
self.p_gains = np.zeros(12)
|
||||
self.d_gains = np.zeros(12)
|
||||
for i in range(12):
|
||||
joint_name = joint_names[i]
|
||||
found = False
|
||||
for dof_name in self.cfg["control"]["stiffness"].keys():
|
||||
if dof_name in joint_name:
|
||||
self.p_gains[i] = self.cfg["control"]["stiffness"][dof_name]
|
||||
self.d_gains[i] = self.cfg["control"]["damping"][dof_name]
|
||||
found = True
|
||||
if not found:
|
||||
self.p_gains[i] = 0.0
|
||||
self.d_gains[i] = 0.0
|
||||
if self.cfg["control"]["control_type"] in ["P", "V"]:
|
||||
print(
|
||||
f"PD gain of joint {joint_name} were not defined, setting them to zero"
|
||||
)
|
||||
|
||||
print(f"p_gains: {self.p_gains}")
|
||||
|
||||
self.commands = np.zeros((1, self.num_commands))
|
||||
self.actions = torch.zeros(12)
|
||||
self.last_actions = torch.zeros(12)
|
||||
self.gravity_vector = np.zeros(3)
|
||||
self.dof_pos = np.zeros(12)
|
||||
self.dof_vel = np.zeros(12)
|
||||
self.body_linear_vel = np.zeros(3)
|
||||
self.body_angular_vel = np.zeros(3)
|
||||
self.joint_pos_target = np.zeros(12)
|
||||
self.joint_vel_target = np.zeros(12)
|
||||
self.torques = np.zeros(12)
|
||||
self.contact_state = np.ones(4)
|
||||
self.test = 0
|
||||
|
||||
self.gait_indices = torch.zeros(self.num_envs, dtype=torch.float)
|
||||
self.clock_inputs = torch.zeros(self.num_envs, 4, dtype=torch.float)
|
||||
|
||||
if "obs_scales" in self.cfg.keys():
|
||||
self.obs_scales = self.cfg["obs_scales"]
|
||||
else:
|
||||
self.obs_scales = self.cfg["normalization"]["obs_scales"]
|
||||
|
||||
def wait_for_state(self):
|
||||
# return self.lcm_bridge.getStates(timeout=2)
|
||||
pass
|
||||
|
||||
def get_obs(self):
|
||||
cmds, reset_timer = self.command_profile.get_command()
|
||||
self.commands[:, :] = cmds[: self.num_commands]
|
||||
|
||||
# self.state = self.wait_for_state()
|
||||
joint_state = self.robot.getJointStates()
|
||||
if joint_state is not None:
|
||||
self.gravity_vector = self.robot.getGravityInBody()
|
||||
self.dof_pos = joint_state['q'][self.unitree_to_policy_map[:,1]]
|
||||
self.dof_vel = joint_state['dq'][self.unitree_to_policy_map[:,1]]
|
||||
|
||||
if reset_timer:
|
||||
self.reset_gait_indices()
|
||||
|
||||
ob = np.concatenate(
|
||||
(
|
||||
self.gravity_vector.reshape(1, -1),
|
||||
self.commands * self.commands_scale,
|
||||
(self.dof_pos - self.default_dof_pos).reshape(1, -1)
|
||||
* self.obs_scales["dof_pos"],
|
||||
self.dof_vel.reshape(1, -1) * self.obs_scales["dof_vel"],
|
||||
torch.clip(
|
||||
self.actions,
|
||||
-self.cfg["normalization"]["clip_actions"],
|
||||
self.cfg["normalization"]["clip_actions"],
|
||||
)
|
||||
.cpu()
|
||||
.detach()
|
||||
.numpy()
|
||||
.reshape(1, -1),
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
if self.cfg["env"]["observe_two_prev_actions"]:
|
||||
ob = np.concatenate(
|
||||
(ob, self.last_actions.cpu().detach().numpy().reshape(1, -1)), axis=1
|
||||
)
|
||||
|
||||
if self.cfg["env"]["observe_clock_inputs"]:
|
||||
ob = np.concatenate((ob, self.clock_inputs), axis=1)
|
||||
|
||||
return torch.tensor(ob, device=self.device).float()
|
||||
|
||||
def get_privileged_observations(self):
|
||||
return None
|
||||
|
||||
def publish_action(self, action, hard_reset=False):
|
||||
# command_for_robot = UnitreeLowCommand()
|
||||
self.joint_pos_target = (
|
||||
action[0, :12].detach().cpu().numpy() * self.cfg["control"]["action_scale"]
|
||||
).flatten()
|
||||
self.joint_pos_target[[0, 3, 6, 9]] *= self.cfg["control"][
|
||||
"hip_scale_reduction"
|
||||
]
|
||||
self.joint_pos_target += self.default_dof_pos
|
||||
self.joint_vel_target = np.zeros(12)
|
||||
# command_for_robot.q_des = self.joint_pos_target
|
||||
# command_for_robot.dq_des = self.joint_vel_target
|
||||
# command_for_robot.kp = self.p_gains
|
||||
# command_for_robot.kd = self.d_gains
|
||||
# command_for_robot.tau_ff = np.zeros(12)
|
||||
if hard_reset:
|
||||
command_for_robot.id = -1
|
||||
|
||||
self.torques = (self.joint_pos_target - self.dof_pos) * self.p_gains + (
|
||||
self.joint_vel_target - self.dof_vel
|
||||
) * self.d_gains
|
||||
# self.lcm_bridge.sendCommands(command_for_robot)
|
||||
self.robot.setCommands(self.joint_pos_target[self.policy_to_unitree_map[:,1]],\
|
||||
self.joint_vel_target[self.policy_to_unitree_map[:,1]],\
|
||||
self.p_gains[self.policy_to_unitree_map[:,1]],\
|
||||
self.d_gains[self.policy_to_unitree_map[:,1]],
|
||||
np.zeros(12))
|
||||
|
||||
def reset(self):
|
||||
self.actions = torch.zeros(12)
|
||||
self.time = time.time()
|
||||
self.timestep = 0
|
||||
return self.get_obs()
|
||||
|
||||
def reset_gait_indices(self):
|
||||
self.gait_indices = torch.zeros(self.num_envs, dtype=torch.float)
|
||||
|
||||
def step(self, actions, hard_reset=False):
|
||||
clip_actions = self.cfg["normalization"]["clip_actions"]
|
||||
self.last_actions = self.actions[:]
|
||||
self.actions = torch.clip(actions[0:1, :], -clip_actions, clip_actions)
|
||||
self.publish_action(self.actions, hard_reset=hard_reset)
|
||||
# time.sleep(max(self.dt - (time.time() - self.time), 0))
|
||||
if self.timestep % 100 == 0:
|
||||
print(f"frq: {1 / (time.time() - self.time)} Hz")
|
||||
self.time = time.time()
|
||||
obs = self.get_obs()
|
||||
|
||||
# clock accounting
|
||||
frequencies = self.commands[:, 4]
|
||||
phases = self.commands[:, 5]
|
||||
offsets = self.commands[:, 6]
|
||||
if self.num_commands == 8:
|
||||
bounds = 0
|
||||
else:
|
||||
bounds = self.commands[:, 7]
|
||||
self.gait_indices = torch.remainder(
|
||||
self.gait_indices + self.dt * frequencies, 1.0
|
||||
)
|
||||
|
||||
if (
|
||||
"pacing_offset" in self.cfg["commands"]
|
||||
and self.cfg["commands"]["pacing_offset"]
|
||||
):
|
||||
self.foot_indices = [
|
||||
self.gait_indices + phases + offsets + bounds,
|
||||
self.gait_indices + bounds,
|
||||
self.gait_indices + offsets,
|
||||
self.gait_indices + phases,
|
||||
]
|
||||
else:
|
||||
self.foot_indices = [
|
||||
self.gait_indices + phases + offsets + bounds,
|
||||
self.gait_indices + offsets,
|
||||
self.gait_indices + bounds,
|
||||
self.gait_indices + phases,
|
||||
]
|
||||
self.clock_inputs[:, 0] = torch.sin(2 * np.pi * self.foot_indices[0])
|
||||
self.clock_inputs[:, 1] = torch.sin(2 * np.pi * self.foot_indices[1])
|
||||
self.clock_inputs[:, 2] = torch.sin(2 * np.pi * self.foot_indices[2])
|
||||
self.clock_inputs[:, 3] = torch.sin(2 * np.pi * self.foot_indices[3])
|
||||
|
||||
infos = {
|
||||
"joint_pos": self.dof_pos[np.newaxis, :],
|
||||
"joint_vel": self.dof_vel[np.newaxis, :],
|
||||
"joint_pos_target": self.joint_pos_target[np.newaxis, :],
|
||||
"joint_vel_target": self.joint_vel_target[np.newaxis, :],
|
||||
"body_linear_vel": self.body_linear_vel[np.newaxis, :],
|
||||
"body_angular_vel": self.body_angular_vel[np.newaxis, :],
|
||||
"contact_state": self.contact_state[np.newaxis, :],
|
||||
"clock_inputs": self.clock_inputs[np.newaxis, :],
|
||||
"body_linear_vel_cmd": self.commands[:, 0:2],
|
||||
"body_angular_vel_cmd": self.commands[:, 2:],
|
||||
"privileged_obs": None,
|
||||
}
|
||||
|
||||
self.timestep += 1
|
||||
return obs, None, None, infos
|
|
@ -8,7 +8,7 @@ import pinocchio as pin
|
|||
from pinocchio.robot_wrapper import RobotWrapper
|
||||
from Go2Py import ASSETS_PATH
|
||||
import os
|
||||
|
||||
from scipy.spatial.transform import Rotation
|
||||
class Go2Sim:
|
||||
def __init__(self, render=True, dt=0.002):
|
||||
|
||||
|
@ -144,6 +144,12 @@ class Go2Sim:
|
|||
'nle':nle
|
||||
}
|
||||
|
||||
def getGravityInBody(self):
|
||||
_, q = self.getPose()
|
||||
R = Rotation.from_quat([q[1], q[2], q[3], q[0]]).as_matrix()
|
||||
g_in_body = R.T@np.array([0.0, 0.0, -1.0]).reshape(3, 1)
|
||||
return g_in_body
|
||||
|
||||
def overheat(self):
|
||||
return False
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -12,7 +12,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -22,7 +22,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -92,7 +92,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -101,7 +101,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -125,7 +125,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -144,10 +144,6 @@
|
|||
" self.user_controller_callback = user_controller_callback\n",
|
||||
"\n",
|
||||
" self.state = \"damping\"\n",
|
||||
"\n",
|
||||
" # self.tracking_kp = np.array(4*[200, 200, 350.]).reshape(12)\n",
|
||||
" # self.tracking_kv = np.array(12*[10.])\n",
|
||||
"\n",
|
||||
" self.tracking_kp = np.array(4*[150, 150, 150.]).reshape(12)\n",
|
||||
" self.tracking_kv = np.array(12*[3.])\n",
|
||||
" self.damping_kv = np.array(12*[2.])\n",
|
||||
|
@ -284,30 +280,7 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Exception in thread Thread-9:\n",
|
||||
"Traceback (most recent call last):\n",
|
||||
" File \"/home/rstaion/miniconda3/envs/b1py/lib/python3.8/threading.py\", line 932, in _bootstrap_inner\n",
|
||||
" self.run()\n",
|
||||
" File \"/home/rstaion/miniconda3/envs/b1py/lib/python3.8/threading.py\", line 870, in run\n",
|
||||
" self._target(*self._args, **self._kwargs)\n",
|
||||
" File \"/tmp/ipykernel_130541/2701327530.py\", line 76, in simUpdate\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" File \"/home/rstaion/projects/rooholla/locomotion/Go2Py/Go2Py/sim/mujoco.py\", line 122, in step\n",
|
||||
" tau = np.diag(self.kp)@(self.q_des-q).reshape(12,1)+ \\\n",
|
||||
"numpy.core._exceptions.UFuncTypeError: ufunc 'subtract' did not contain a loop with signature matching types (dtype('float64'), dtype('<U1')) -> None\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fsm = FSM(robot, remote, safety)"
|
||||
]
|
||||
|
|
|
@ -148,7 +148,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.18"
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"robot = Go2Sim()\n",
|
||||
"robot.standUp()"
|
||||
"robot.standUpReset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -29,10 +29,10 @@
|
|||
"import mujoco\n",
|
||||
"import time\n",
|
||||
"q,dq = robot.getJointStates()\n",
|
||||
"robot.standUp()\n",
|
||||
"robot.standUpReset()\n",
|
||||
"for i in range(100000):\n",
|
||||
" q,dq = robot.getJointStates()\n",
|
||||
" tau = 20*np.eye(12)@(robot.q0 - q).reshape(12,1)\n",
|
||||
" state = robot.getJointStates()\n",
|
||||
" tau = 20*np.eye(12)@(robot.q0 - state['q']).reshape(12,1)\n",
|
||||
" robot.setCommands(np.zeros(12), np.zeros(12), np.zeros(12), np.zeros(12), tau)\n",
|
||||
" robot.step()"
|
||||
]
|
||||
|
@ -82,13 +82,6 @@
|
|||
"source": [
|
||||
"robot.getSiteJacobian('FR_foot')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from Go2Py.sim.mujoco import Go2Sim\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"robot = Go2Sim()\n",
|
||||
"robot.standUpReset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from Go2Py.controllers.walk_these_ways import *"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"checkpoint_path = \"/home/rstaion/projects/rooholla/locomotion/Go2Py/Go2Py/assets/checkpoints/walk_these_ways/\"\n",
|
||||
"\n",
|
||||
"cfg = loadParameters(checkpoint_path)\n",
|
||||
"policy = Policy(checkpoint_path)\n",
|
||||
"command_profile = CommandInterface()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"p_gains: [20. 20. 20. 20. 20. 20. 20. 20. 20. 20. 20. 20.]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent = WalkTheseWaysAgent(cfg, command_profile, robot)\n",
|
||||
"agent = HistoryWrapper(agent)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"control_dt = cfg[\"control\"][\"decimation\"] * cfg[\"sim\"][\"dt\"]\n",
|
||||
"simulation_dt = robot.dt\n",
|
||||
"obs = agent.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"frq: 12.822735624382831 Hz\n",
|
||||
"frq: 47.954632763194 Hz\n",
|
||||
"frq: 53.309743511528 Hz\n",
|
||||
"frq: 54.65674559220214 Hz\n",
|
||||
"frq: 45.29730547005778 Hz\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"robot.reset()\n",
|
||||
"obs = agent.reset()\n",
|
||||
"for i in range(5000):\n",
|
||||
" policy_info = {}\n",
|
||||
" action = policy(obs, policy_info)\n",
|
||||
" if i % (control_dt // simulation_dt) == 0:\n",
|
||||
" obs, ret, done, info = agent.step(action)\n",
|
||||
" robot.step()\n",
|
||||
" command_profile.yaw_vel_cmd = 1.2\n",
|
||||
" command_profile.x_vel_cmd = 0.8\n",
|
||||
" command_profile.y_vel_cmd = 0.0\n",
|
||||
" command_profile.stance_width_cmd=0.2\n",
|
||||
" command_profile.footswing_height_cmd=-0.05\n",
|
||||
" command_profile.step_frequency_cmd = 2.5\n",
|
||||
" time.sleep(robot.dt/4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "b1-env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Loading…
Reference in New Issue