import copy import time import os import numpy as np import torch from go1_gym_deploy.utils.logger import MultiLogger class DeploymentRunner: def __init__(self, experiment_name="unnamed", se=None, log_root="."): self.agents = {} self.policy = None self.command_profile = None self.logger = MultiLogger() self.se = se self.vision_server = None self.log_root = log_root self.init_log_filename() self.control_agent_name = None self.command_agent_name = None self.triggered_commands = {i: None for i in range(4)} # command profiles for each action button on the controller self.button_states = np.zeros(4) self.is_currently_probing = False self.is_currently_logging = [False, False, False, False] def init_log_filename(self): datetime = time.strftime("%Y/%m_%d/%H_%M_%S") for i in range(100): try: os.makedirs(f"{self.log_root}/{datetime}_{i}") self.log_filename = f"{self.log_root}/{datetime}_{i}/log.pkl" return except FileExistsError: continue def add_open_loop_agent(self, agent, name): self.agents[name] = agent self.logger.add_robot(name, agent.env.cfg) def add_control_agent(self, agent, name): self.control_agent_name = name self.agents[name] = agent self.logger.add_robot(name, agent.env.cfg) def add_vision_server(self, vision_server): self.vision_server = vision_server def set_command_agents(self, name): self.command_agent = name def add_policy(self, policy): self.policy = policy def add_command_profile(self, command_profile): self.command_profile = command_profile def calibrate(self, wait=True, low=False): # first, if the robot is not in nominal pose, move slowly to the nominal pose for agent_name in self.agents.keys(): if hasattr(self.agents[agent_name], "get_obs"): agent = self.agents[agent_name] agent.get_obs() joint_pos = agent.dof_pos if low: final_goal = np.array([0., 0.3, -0.7, 0., 0.3, -0.7, 0., 0.3, -0.7, 0., 0.3, -0.7,]) else: final_goal = np.zeros(12) nominal_joint_pos = agent.default_dof_pos print(f"About to calibrate; the robot will stand [Press R2 to calibrate]") while wait: self.button_states = self.command_profile.get_buttons() if self.command_profile.state_estimator.right_lower_right_switch_pressed: self.command_profile.state_estimator.right_lower_right_switch_pressed = False break cal_action = np.zeros((agent.num_envs, agent.num_actions)) target_sequence = [] target = joint_pos - nominal_joint_pos while np.max(np.abs(target - final_goal)) > 0.01: target -= np.clip((target - final_goal), -0.05, 0.05) target_sequence += [copy.deepcopy(target)] for target in target_sequence: next_target = target if isinstance(agent.cfg, dict): hip_reduction = agent.cfg["control"]["hip_scale_reduction"] action_scale = agent.cfg["control"]["action_scale"] else: hip_reduction = agent.cfg.control.hip_scale_reduction action_scale = agent.cfg.control.action_scale next_target[[0, 3, 6, 9]] /= hip_reduction next_target = next_target / action_scale cal_action[:, 0:12] = next_target agent.step(torch.from_numpy(cal_action)) agent.get_obs() time.sleep(0.05) print("Starting pose calibrated [Press R2 to start controller]") while True: self.button_states = self.command_profile.get_buttons() if self.command_profile.state_estimator.right_lower_right_switch_pressed: self.command_profile.state_estimator.right_lower_right_switch_pressed = False break for agent_name in self.agents.keys(): obs = self.agents[agent_name].reset() if agent_name == self.control_agent_name: control_obs = obs return control_obs def run(self, num_log_steps=1000000000, max_steps=100000000, logging=True): assert self.control_agent_name is not None, "cannot deploy, runner has no control agent!" assert self.policy is not None, "cannot deploy, runner has no policy!" assert self.command_profile is not None, "cannot deploy, runner has no command profile!" # TODO: add basic test for comms for agent_name in self.agents.keys(): obs = self.agents[agent_name].reset() if agent_name == self.control_agent_name: control_obs = obs control_obs = self.calibrate(wait=True) # now, run control loop try: for i in range(max_steps): policy_info = {} action = self.policy(control_obs, policy_info) for agent_name in self.agents.keys(): obs, ret, done, info = self.agents[agent_name].step(action) info.update(policy_info) info.update({"observation": obs, "reward": ret, "done": done, "timestep": i, "time": i * self.agents[self.control_agent_name].dt, "action": action, "rpy": self.agents[self.control_agent_name].se.get_rpy(), "torques": self.agents[self.control_agent_name].torques}) if logging: self.logger.log(agent_name, info) if agent_name == self.control_agent_name: control_obs, control_ret, control_done, control_info = obs, ret, done, info # bad orientation emergency stop rpy = self.agents[self.control_agent_name].se.get_rpy() if abs(rpy[0]) > 1.6 or abs(rpy[1]) > 1.6: self.calibrate(wait=False, low=True) # check for logging command prev_button_states = self.button_states[:] self.button_states = self.command_profile.get_buttons() if self.command_profile.state_estimator.left_lower_left_switch_pressed: if not self.is_currently_probing: print("START LOGGING") self.is_currently_probing = True self.agents[self.control_agent_name].set_probing(True) self.init_log_filename() self.logger.reset() else: print("SAVE LOG") self.is_currently_probing = False self.agents[self.control_agent_name].set_probing(False) # calibrate, log, and then resume control control_obs = self.calibrate(wait=False) self.logger.save(self.log_filename) self.init_log_filename() self.logger.reset() time.sleep(1) control_obs = self.agents[self.control_agent_name].reset() self.command_profile.state_estimator.left_lower_left_switch_pressed = False for button in range(4): if self.command_profile.currently_triggered[button]: if not self.is_currently_logging[button]: print("START LOGGING") self.is_currently_logging[button] = True self.init_log_filename() self.logger.reset() else: if self.is_currently_logging[button]: print("SAVE LOG") self.is_currently_logging[button] = False # calibrate, log, and then resume control control_obs = self.calibrate(wait=False) self.logger.save(self.log_filename) self.init_log_filename() self.logger.reset() time.sleep(1) control_obs = self.agents[self.control_agent_name].reset() if self.command_profile.state_estimator.right_lower_right_switch_pressed: control_obs = self.calibrate(wait=False) time.sleep(1) self.command_profile.state_estimator.right_lower_right_switch_pressed = False # self.button_states = self.command_profile.get_buttons() while not self.command_profile.state_estimator.right_lower_right_switch_pressed: time.sleep(0.01) # self.button_states = self.command_profile.get_buttons() self.command_profile.state_estimator.right_lower_right_switch_pressed = False # finally, return to the nominal pose control_obs = self.calibrate(wait=False) self.logger.save(self.log_filename) except KeyboardInterrupt: self.logger.save(self.log_filename)