# Source: walk-these-ways-go2/go1_gym_deploy/utils/deployment_runner.py
# (223 lines, 9.6 KiB, Python; revision dated 2024-01-28 17:11:38 +08:00)
import copy
import time
import os
import numpy as np
import torch
from go1_gym_deploy.utils.logger import MultiLogger
class DeploymentRunner:
    """Drive one or more agents with a policy and log what happens.

    The runner owns a dictionary of named agents, exactly one of which is the
    "control" agent whose observations are fed to the policy.  A command
    profile supplies button states and a ``state_estimator`` whose
    ``*_switch_pressed`` flags are polled (and cleared) to start/stop logging
    and to pause/resume the controller.
    """

    # Upper bound on numbered suffixes tried when creating a fresh log directory.
    _MAX_LOG_DIR_ATTEMPTS = 100
    # Pause between polls while calibrate() waits for a button press; matches
    # the polling interval already used by the pause loop in run().
    _BUTTON_POLL_INTERVAL_S = 0.01
    # Action entries divided by ``hip_scale_reduction`` during calibration.
    # NOTE(review): presumably the four hip joints -- confirm against the agent.
    _HIP_ACTION_INDICES = [0, 3, 6, 9]

    def __init__(self, experiment_name="unnamed", se=None, log_root="."):
        """Create an empty runner and allocate its first log directory.

        Args:
            experiment_name: Accepted for interface compatibility; not used
                by this class.
            se: Stored as ``self.se``; not otherwise used by this class.
            log_root: Directory under which timestamped log folders are made.
        """
        self.agents = {}
        self.policy = None
        self.command_profile = None
        self.logger = MultiLogger()
        self.se = se
        self.vision_server = None
        self.log_root = log_root
        self.init_log_filename()
        self.control_agent_name = None
        self.command_agent_name = None
        # Legacy alias written by set_command_agents(); kept for old callers.
        self.command_agent = None

        # command profiles for each action button on the controller
        self.triggered_commands = {i: None for i in range(4)}
        self.button_states = np.zeros(4)

        self.is_currently_probing = False
        self.is_currently_logging = [False, False, False, False]

    def init_log_filename(self):
        """Create a new, unique log directory and point ``log_filename`` at it.

        The directory is ``<log_root>/<Y>/<m_d>/<H_M_S>_<i>`` for the first
        ``i`` that does not exist yet; ``log_filename`` becomes ``log.pkl``
        inside it.

        Raises:
            RuntimeError: If every candidate suffix already exists.  The
                original silently returned here, leaving ``log_filename``
                unset or pointing at an older log that would be overwritten.
        """
        timestamp = time.strftime("%Y/%m_%d/%H_%M_%S")
        for attempt in range(self._MAX_LOG_DIR_ATTEMPTS):
            log_dir = f"{self.log_root}/{timestamp}_{attempt}"
            try:
                os.makedirs(log_dir)
            except FileExistsError:
                continue
            self.log_filename = f"{log_dir}/log.pkl"
            return
        raise RuntimeError(
            f"could not create a unique log directory under "
            f"{self.log_root}/{timestamp}_* after "
            f"{self._MAX_LOG_DIR_ATTEMPTS} attempts"
        )

    def add_open_loop_agent(self, agent, name):
        """Register ``agent`` under ``name`` and announce it to the logger."""
        self.agents[name] = agent
        self.logger.add_robot(name, agent.env.cfg)

    def add_control_agent(self, agent, name):
        """Register ``agent`` under ``name`` and make it the control agent."""
        self.control_agent_name = name
        self.agents[name] = agent
        self.logger.add_robot(name, agent.env.cfg)

    def add_vision_server(self, vision_server):
        """Store the vision server handle."""
        self.vision_server = vision_server

    def set_command_agents(self, name):
        """Record the name of the command agent.

        Writes both ``command_agent_name`` (the attribute declared in
        ``__init__``) and the legacy ``command_agent`` attribute that earlier
        versions of this method set.
        """
        self.command_agent_name = name
        self.command_agent = name

    def add_policy(self, policy):
        """Store the policy callable used by :meth:`run`."""
        self.policy = policy

    def add_command_profile(self, command_profile):
        """Store the command profile polled for buttons and switch flags."""
        self.command_profile = command_profile

    def _wait_for_r2_press(self):
        """Block until ``right_lower_right_switch_pressed`` is set, then clear it.

        The on-screen prompts call this switch "R2".  ``button_states`` is
        refreshed on every poll.  A short sleep between polls avoids spinning
        at full CPU.
        NOTE(review): assumes the flag stays set until cleared here, so a
        press during the sleep is not lost -- confirm against the estimator.
        """
        estimator = self.command_profile.state_estimator
        while True:
            self.button_states = self.command_profile.get_buttons()
            if estimator.right_lower_right_switch_pressed:
                estimator.right_lower_right_switch_pressed = False
                return
            time.sleep(self._BUTTON_POLL_INTERVAL_S)

    @staticmethod
    def _control_scales(cfg):
        """Return ``(hip_scale_reduction, action_scale)`` from a dict or object cfg."""
        if isinstance(cfg, dict):
            control = cfg["control"]
            return control["hip_scale_reduction"], control["action_scale"]
        return cfg.control.hip_scale_reduction, cfg.control.action_scale

    def calibrate(self, wait=True, low=False):
        """Move every observing agent slowly to a goal pose, then reset all agents.

        For each agent that has a ``get_obs`` method, the offset of its joint
        positions from ``default_dof_pos`` is stepped toward the goal in
        increments of at most 0.05 per joint, one ``agent.step`` every 0.05 s.
        The method then always waits for an R2 press before resetting every
        agent.

        Args:
            wait: If true, also wait for an R2 press before starting to move.
            low: If true, the goal offset is ``[0, 0.3, -0.7]`` repeated four
                times; otherwise it is all zeros (the default pose).

        Returns:
            The observation returned by the control agent's ``reset()``, or
            ``None`` if no agent was calibrated or the control agent is not
            registered (the original raised ``UnboundLocalError`` here).
        """
        control_obs = None

        # first, if the robot is not in nominal pose, move slowly to the nominal pose
        for agent in self.agents.values():
            if not hasattr(agent, "get_obs"):
                continue

            agent.get_obs()
            joint_pos = agent.dof_pos
            if low:
                final_goal = np.array([0., 0.3, -0.7,
                                       0., 0.3, -0.7,
                                       0., 0.3, -0.7,
                                       0., 0.3, -0.7])
            else:
                final_goal = np.zeros(12)
            nominal_joint_pos = agent.default_dof_pos

            print("About to calibrate; the robot will stand [Press R2 to calibrate]")
            if wait:
                self._wait_for_r2_press()

            cal_action = np.zeros((agent.num_envs, agent.num_actions))

            # Build the list of intermediate offsets, each at most 0.05 closer
            # to the goal per joint, until every joint is within 0.01 of it.
            target_sequence = []
            target = joint_pos - nominal_joint_pos
            while np.max(np.abs(target - final_goal)) > 0.01:
                target -= np.clip((target - final_goal), -0.05, 0.05)
                target_sequence += [copy.deepcopy(target)]

            # The scaling factors do not change between steps; read them once.
            hip_reduction, action_scale = self._control_scales(agent.cfg)

            for step_target in target_sequence:
                # Convert the joint offset into the action the agent expects.
                step_target[self._HIP_ACTION_INDICES] /= hip_reduction
                scaled_target = step_target / action_scale
                cal_action[:, 0:12] = scaled_target
                agent.step(torch.from_numpy(cal_action))
                agent.get_obs()
                time.sleep(0.05)

            print("Starting pose calibrated [Press R2 to start controller]")
            # This second wait is unconditional: it does not depend on `wait`.
            self._wait_for_r2_press()

            for name, other_agent in self.agents.items():
                obs = other_agent.reset()
                if name == self.control_agent_name:
                    control_obs = obs

        return control_obs

    def _start_new_log(self):
        """Announce logging, switch to a fresh log file and clear the logger."""
        print("START LOGGING")
        self.init_log_filename()
        self.logger.reset()

    def _save_log_and_resume(self):
        """Recalibrate, save the current log, start a fresh one and reset.

        Returns:
            The observation from resetting the control agent, to be used as
            the next policy input.
        """
        print("SAVE LOG")
        # calibrate, log, and then resume control
        self.calibrate(wait=False)
        self.logger.save(self.log_filename)
        self.init_log_filename()
        self.logger.reset()
        time.sleep(1)
        return self.agents[self.control_agent_name].reset()

    def run(self, num_log_steps=1000000000, max_steps=100000000, logging=True):
        """Calibrate, then run the policy control loop until done or interrupted.

        Each iteration queries the policy with the control agent's latest
        observation, steps every agent with the resulting action, optionally
        logs the step, and then handles the emergency stop and the
        logging/pause switches.  The log is saved when the loop finishes and
        on ``KeyboardInterrupt``.

        Args:
            num_log_steps: Accepted for interface compatibility; not used.
            max_steps: Maximum number of control iterations.
            logging: If true, every agent step is passed to the logger.

        Raises:
            AssertionError: If the control agent, policy or command profile
                has not been set.
        """
        assert self.control_agent_name is not None, "cannot deploy, runner has no control agent!"
        assert self.policy is not None, "cannot deploy, runner has no policy!"
        assert self.command_profile is not None, "cannot deploy, runner has no command profile!"

        # TODO: add basic test for comms

        for agent in self.agents.values():
            agent.reset()

        control_obs = self.calibrate(wait=True)
        control_agent = self.agents[self.control_agent_name]

        # now, run control loop
        try:
            for i in range(max_steps):
                policy_info = {}
                action = self.policy(control_obs, policy_info)

                for agent_name, agent in self.agents.items():
                    obs, ret, done, info = agent.step(action)

                    info.update(policy_info)
                    info.update({"observation": obs, "reward": ret, "done": done, "timestep": i,
                                 "time": i * control_agent.dt, "action": action,
                                 "rpy": control_agent.se.get_rpy(),
                                 "torques": control_agent.torques})

                    if logging:
                        self.logger.log(agent_name, info)

                    if agent_name == self.control_agent_name:
                        control_obs = obs

                # bad orientation emergency stop
                # NOTE(review): rpy[0] / rpy[1] are presumably roll and pitch
                # in radians -- confirm against the state estimator.
                rpy = control_agent.se.get_rpy()
                if abs(rpy[0]) > 1.6 or abs(rpy[1]) > 1.6:
                    # calibrate() resets the agents, so continue from the
                    # fresh observation it returns (as every other call site
                    # does) instead of the pre-reset one.
                    control_obs = self.calibrate(wait=False, low=True)

                # check for logging command
                self.button_states = self.command_profile.get_buttons()
                estimator = self.command_profile.state_estimator

                if estimator.left_lower_left_switch_pressed:
                    if not self.is_currently_probing:
                        self.is_currently_probing = True
                        control_agent.set_probing(True)
                        self._start_new_log()
                    else:
                        self.is_currently_probing = False
                        control_agent.set_probing(False)
                        control_obs = self._save_log_and_resume()
                    estimator.left_lower_left_switch_pressed = False

                for button in range(4):
                    if self.command_profile.currently_triggered[button]:
                        if not self.is_currently_logging[button]:
                            self.is_currently_logging[button] = True
                            self._start_new_log()
                    elif self.is_currently_logging[button]:
                        self.is_currently_logging[button] = False
                        control_obs = self._save_log_and_resume()

                if estimator.right_lower_right_switch_pressed:
                    # Pause: recalibrate, then idle until the switch is
                    # pressed again.
                    control_obs = self.calibrate(wait=False)
                    time.sleep(1)
                    estimator.right_lower_right_switch_pressed = False
                    while not estimator.right_lower_right_switch_pressed:
                        time.sleep(0.01)
                    estimator.right_lower_right_switch_pressed = False

            # finally, return to the nominal pose
            control_obs = self.calibrate(wait=False)
            self.logger.save(self.log_filename)

        except KeyboardInterrupt:
            self.logger.save(self.log_filename)