From 694a5dbca836fb118fb85c2a249a532d8e68996f Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Mon, 15 Jul 2024 17:35:11 +0200
Subject: [PATCH] fix color

---
 lerobot/common/robot_devices/robots/koch.py |  5 +++
 lerobot/scripts/control_robot.py            | 40 ++++++++++++++-------
 2 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/lerobot/common/robot_devices/robots/koch.py b/lerobot/common/robot_devices/robots/koch.py
index be8f2866..c6d1a4d4 100644
--- a/lerobot/common/robot_devices/robots/koch.py
+++ b/lerobot/common/robot_devices/robots/koch.py
@@ -477,7 +477,9 @@ class KochRobot:
         # Read follower position
         follower_pos = {}
         for name in self.follower_arms:
+            now = time.perf_counter()
             follower_pos[name] = self.follower_arms[name].read("Present_Position")
+            self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
 
         # Create state by concatenating follower current position
         state = []
@@ -489,7 +491,10 @@ class KochRobot:
         # Capture images from cameras
         images = {}
         for name in self.cameras:
+            now = time.perf_counter()
             images[name] = self.cameras[name].async_read()
+            self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
+            self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
 
         # Populate output dictionnaries and format to pytorch
         obs_dict = {}
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 9c691fe4..1ee11005 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -95,6 +95,7 @@ import tqdm
 from huggingface_hub import create_branch
 from omegaconf import DictConfig
 from PIL import Image
+from termcolor import colored
 
 # from safetensors.torch import load_file, save_file
 from lerobot.common.datasets.compute_stats import compute_stats
@@ -137,7 +138,7 @@ def none_or_int(value):
     return int(value)
 
 
-def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
+def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
     log_items = []
     if episode_index is not None:
         log_items += [f"ep:{episode_index}"]
@@ -170,7 +171,12 @@
         if key in robot.logs:
             log_dt(f"dtR{name}", robot.logs[key])
 
-    logging.info(" ".join(log_items))
+    info_str = " ".join(log_items)
+    if fps is not None:
+        actual_fps = 1 / dt_s
+        if actual_fps < fps - 1:
+            info_str = colored(info_str, "yellow")
+    logging.info(info_str)
 
 
 def get_is_headless():
@@ -201,7 +207,7 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
             busy_wait(1 / fps - dt_s)
 
         dt_s = time.perf_counter() - now
-        log_control_info(robot, dt_s)
+        log_control_info(robot, dt_s, fps=fps)
 
         if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
             break
@@ -272,7 +278,7 @@ def record_dataset(
         busy_wait(1 / fps - dt_s)
 
         dt_s = time.perf_counter() - now
-        log_control_info(robot, dt_s)
+        log_control_info(robot, dt_s, fps=fps)
 
         timestamp = time.perf_counter() - start_time
 
@@ -352,7 +358,7 @@ def record_dataset(
         busy_wait(1 / fps - dt_s)
 
         dt_s = time.perf_counter() - now
-        log_control_info(robot, dt_s)
+        log_control_info(robot, dt_s, fps=fps)
 
         timestamp = time.perf_counter() - start_time
 
@@ -396,7 +402,7 @@ def record_dataset(
         done[-1] = True
         ep_dict["next.done"] = done
 
-        ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
+        ep_path = episodes_dir / f"episode_{episode_index}.pth"
         print("Saving episode dictionary...")
         torch.save(ep_dict, ep_path)
 
@@ -458,7 +464,7 @@ def record_dataset(
     logging.info("Concatenating episodes")
    ep_dicts = []
     for episode_index in tqdm.tqdm(range(num_episodes)):
-        ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
+        ep_path = episodes_dir / f"episode_{episode_index}.pth"
         ep_dict = torch.load(ep_path)
         ep_dicts.append(ep_dict)
     data_dict = concatenate_episodes(ep_dicts)
@@ -534,16 +540,18 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
         busy_wait(1 / fps - dt_s)
 
         dt_s = time.perf_counter() - now
-        log_control_info(robot, dt_s)
+        log_control_info(robot, dt_s, fps=fps)
 
 
 def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
     # TODO(rcadene): Add option to record eval dataset and logs
-    policy.eval()
 
     # Check device is available
     device = get_safe_torch_device(hydra_cfg.device, log=True)
 
+    policy.eval()
+    policy.to(device)
+
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_global_seed(hydra_cfg.seed)
@@ -561,24 +569,30 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run
 
         with (
             torch.inference_mode(),
-            torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
+            torch.autocast(device_type=device.type)
+            if device.type == "cuda" and hydra_cfg.use_amp
+            else nullcontext(),
         ):
             # add batch dimension to 1
             for name in observation:
                 observation[name] = observation[name].unsqueeze(0)
 
+            if device.type == "mps":
+                for name in observation:
+                    observation[name] = observation[name].to(device)
+
             action = policy.select_action(observation)
 
         # remove batch dimension
         action = action.squeeze(0)
 
-        robot.send_action(action)
+        robot.send_action(action.to("cpu"))
 
         dt_s = time.perf_counter() - now
         busy_wait(1 / fps - dt_s)
 
         dt_s = time.perf_counter() - now
-        log_control_info(robot, dt_s)
+        log_control_info(robot, dt_s, fps=fps)
 
         if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
             break
@@ -643,7 +657,7 @@ if __name__ == "__main__":
         default=1,
         help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
    )
-    parser.add_argument(
+    parser_record.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
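
Note on the koch.py hunks: each blocking device read is now bracketed by time.perf_counter() and the elapsed time is stored in self.logs, which log_control_info later reads back. A minimal standalone sketch of that pattern, assuming nothing beyond the standard library (timed_read and the lambda stand-in are illustrative names, not part of the repository):

import time


def timed_read(name, read_fn, logs):
    # Measure how long a blocking read takes and record it, as the koch.py hunks do.
    start = time.perf_counter()
    value = read_fn()
    logs[f"read_{name}_dt_s"] = time.perf_counter() - start
    return value


logs = {}
pos = timed_read("follower_main", lambda: [0.0] * 6, logs)  # the lambda stands in for a motor-bus read
print(logs)  # e.g. {'read_follower_main_dt_s': 1.5e-06}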
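
The fps argument threaded through log_control_info colors the whole log line yellow (via termcolor) whenever the measured loop rate falls more than 1 Hz below the requested rate. A sketch that exercises the check in isolation; maybe_warn_slow_loop and the sample dt values are made up, while the threshold mirrors the patch:

import logging

from termcolor import colored


def maybe_warn_slow_loop(info_str, dt_s, fps):
    # Mirror the check added to log_control_info: color the line if actual fps < fps - 1.
    if fps is not None:
        actual_fps = 1 / dt_s
        if actual_fps < fps - 1:
            info_str = colored(info_str, "yellow")
    return info_str


logging.basicConfig(level=logging.INFO)
logging.info(maybe_warn_slow_loop("dt: 40.00ms (25.0hz)", dt_s=0.040, fps=30))  # 25 Hz < 29 Hz -> yellow
logging.info(maybe_warn_slow_loop("dt: 33.00ms (30.3hz)", dt_s=0.033, fps=30))  # within the 1 Hz margin -> plain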
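
The run_policy changes settle on one device-handling pattern: move the policy to the configured device once, enable autocast only on CUDA, push observations onto the device (the patch does this explicitly for MPS), and bring the action back to CPU before robot.send_action. A rough, self-contained sketch of that flow; the Linear module and the observation dict are placeholders for the real policy and robot observation, and inputs are moved to the device unconditionally for brevity:

from contextlib import nullcontext

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = True

policy = torch.nn.Linear(4, 2)  # placeholder for the real policy module
policy.eval()
policy.to(device)

observation = {"observation.state": torch.randn(4)}  # placeholder observation

with (
    torch.inference_mode(),
    torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
    # Add a batch dimension and move inputs onto the device before inference.
    batched = {name: obs.unsqueeze(0).to(device) for name, obs in observation.items()}
    action = policy(batched["observation.state"]).squeeze(0)

# The action goes back to CPU before being handed to the robot, as in robot.send_action(action.to("cpu")).
action = action.to("cpu")
print(action)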