Add log_control_info

This commit is contained in:
Remi Cadene 2024-07-06 17:54:22 +02:00
parent 3ff789c181
commit d83d34d9b3
4 changed files with 157 additions and 66 deletions

View File

@ -6,6 +6,10 @@ from pathlib import Path
from threading import Thread from threading import Thread
import cv2 import cv2
# Using 1 thread to avoid blocking the main thread.
# Especially useful during data collection when other threads are used
# to save the images.
cv2.setNumThreads(1)
import numpy as np import numpy as np
from lerobot.common.robot_devices.cameras.utils import save_color_image from lerobot.common.robot_devices.cameras.utils import save_color_image
@ -120,7 +124,6 @@ class OpenCVCamera:
self.camera = None self.camera = None
self.is_connected = False self.is_connected = False
self.threads = {} self.threads = {}
self.results = {} self.results = {}
self.logs = {} self.logs = {}

View File

@ -154,14 +154,19 @@ def get_group_sync_key(data_name, motor_names):
group_key = f"{data_name}_" + "_".join(motor_names) group_key = f"{data_name}_" + "_".join(motor_names)
return group_key return group_key
def get_thread_name(fn_name, data_name, motor_names): def get_result_name(fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names) group_key = get_group_sync_key(data_name, motor_names)
thread_name = f"{fn_name}_{group_key}" rslt_name = f"{fn_name}_{group_key}"
return thread_name return rslt_name
def get_queue_name(fn_name, data_name, motor_names):
    """Return the queue identifier for `fn_name` applied to a motor group.

    The identifier is `fn_name`, an underscore, and the group sync key
    computed by `get_group_sync_key(data_name, motor_names)`.
    """
    return f"{fn_name}_{get_group_sync_key(data_name, motor_names)}"
def get_log_name(var_name, fn_name, data_name, motor_names): def get_log_name(var_name, fn_name, data_name, motor_names):
thread_name = get_thread_name(fn_name, data_name, motor_names) group_key = get_group_sync_key(data_name, motor_names)
log_name = f"{var_name}_{thread_name}" log_name = f"{var_name}_{fn_name}_{group_key}"
return log_name return log_name
class TorqueMode(enum.Enum): class TorqueMode(enum.Enum):
@ -197,6 +202,18 @@ class DynamixelMotorsBus:
if extra_model_control_table: if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table) self.model_ctrl_table.update(extra_model_control_table)
self.port_handler = None
self.packet_handler = None
self.calibration = None
self.is_connected = False
self.group_readers = {}
self.group_writers = {}
self.logs = {}
def connect(self):
if self.is_connected:
raise ValueError(f"KochRobot is already connected.")
self.port_handler = PortHandler(self.port) self.port_handler = PortHandler(self.port)
self.packet_handler = PacketHandler(PROTOCOL_VERSION) self.packet_handler = PacketHandler(PROTOCOL_VERSION)
@ -206,15 +223,13 @@ class DynamixelMotorsBus:
self.port_handler.setBaudRate(BAUD_RATE) self.port_handler.setBaudRate(BAUD_RATE)
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS) self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
self.group_readers = {} # for async_read and async_write
self.group_writers = {} self.thread = None
self.async_read_args = {}
self.calibration = None self.write_queue = Queue()
self.threads = {}
self.queues = {}
self.results = {} self.results = {}
self.logs = {}
self.is_connected = True
@property @property
def motor_names(self) -> list[int]: def motor_names(self) -> list[int]:
@ -258,6 +273,9 @@ class DynamixelMotorsBus:
return values return values
def read(self, data_name, motor_names: list[str] | None = None): def read(self, data_name, motor_names: list[str] | None = None):
if not self.is_connected:
raise ValueError(f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`.")
start_time = time.perf_counter() start_time = time.perf_counter()
if motor_names is None: if motor_names is None:
@ -320,6 +338,9 @@ class DynamixelMotorsBus:
return values return values
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise ValueError(f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`.")
start_time = time.perf_counter() start_time = time.perf_counter()
if motor_names is None: if motor_names is None:
@ -406,10 +427,20 @@ class DynamixelMotorsBus:
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names) ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
self.logs[ts_utc_name] = capture_timestamp_utc() self.logs[ts_utc_name] = capture_timestamp_utc()
def read_loop(self, data_name, motor_names: list[str] | None = None): def read_write_loop(self, async_read_args, write_queue):
while True: while True:
thread_name = get_thread_name("read", data_name, motor_names) for group_name, read_args in async_read_args.items():
self.results[thread_name] = self.read(data_name, motor_names) self.results[group_name] = self.read(*read_args)
if write_queue.empty():
continue
write_args = write_queue.get()
if write_args is None: # A way to terminate the thread
break
self.write(*write_args)
write_queue.task_done()
def async_read(self, data_name, motor_names: list[str] | None = None): def async_read(self, data_name, motor_names: list[str] | None = None):
if motor_names is None: if motor_names is None:
@ -418,32 +449,25 @@ class DynamixelMotorsBus:
if isinstance(motor_names, str): if isinstance(motor_names, str):
motor_names = [motor_names] motor_names = [motor_names]
thread_name = get_thread_name("read", data_name, motor_names) if self.thread is None:
self.thread = Thread(target=self.read_write_loop, args=(self.async_read_args, self.write_queue))
self.thread.daemon = True
self.thread.start()
if thread_name not in self.threads: group_name = get_group_sync_key(data_name, motor_names)
self.threads[thread_name] = Thread(target=self.read_loop, args=(data_name, motor_names)) self.async_read_args[group_name] = (data_name, motor_names)
self.threads[thread_name].daemon = True
self.threads[thread_name].start()
FPS = 200 FPS = 200
num_tries = 0 num_tries = 0
while thread_name not in self.results: while group_name not in self.results:
num_tries += 1 num_tries += 1
time.sleep(1 / FPS) time.sleep(1 / FPS)
if num_tries > FPS: if num_tries > FPS:
if self.threads[thread_name].ident is None and not self.threads[thread_name].is_alive(): if self.thread.ident is None and not self.thread.is_alive():
raise Exception(f"The thread responsible for `self.async_read({data_name}, {motor_names})` took too much time to start. There might be an issue. Verify that `self.threads[thread_name].start()` has been called.") raise Exception(f"The thread responsible for `self.async_read({data_name}, {motor_names})` took too much time to start. There might be an issue. Verify that `self.threads[thread_name].start()` has been called.")
# ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names) # ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
return self.results[thread_name] #, self.logs[ts_utc_name] return self.results[group_name] #, self.logs[ts_utc_name]
def write_loop(self, data_name, queue: Queue, motor_names: list[str] | None = None):
while True:
values = queue.get()
if values is None: # A way to terminate the thread
break
self.write(data_name, values, motor_names)
queue.task_done()
def async_write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): def async_write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if motor_names is None: if motor_names is None:
@ -457,38 +481,33 @@ class DynamixelMotorsBus:
values = np.array(values) values = np.array(values)
thread_name = get_thread_name("write", data_name, motor_names) if self.thread is None:
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names) self.thread = Thread(target=self.read_write_loop, args=(self.async_read_args, self.write_queue))
self.thread.daemon = True
self.thread.start()
if thread_name not in self.threads: self.write_queue.put((data_name, values, motor_names))
self.queues[thread_name] = Queue()
self.threads[thread_name] = Thread(target=self.write_loop, args=(data_name, self.queues[thread_name], motor_names))
self.threads[thread_name].daemon = True
self.threads[thread_name].start()
self.queues[thread_name].put(values)
FPS = 200 FPS = 200
num_tries = 0 num_tries = 0
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
while ts_utc_name not in self.logs: while ts_utc_name not in self.logs:
num_tries += 1 num_tries += 1
time.sleep(1 / FPS) time.sleep(1 / FPS)
if num_tries > FPS: if num_tries > FPS:
if self.threads[thread_name].ident is None and not self.threads[thread_name].is_alive(): if self.thread.ident is None and not self.thread.is_alive():
raise Exception(f"The thread responsible for `self.async_write({data_name}, {values}, {motor_names})` took too much time to start. There might be an issue. Verify that `self.threads[thread_name].start()` has been called.") raise Exception(f"The thread responsible for `self.async_write({data_name}, {values}, {motor_names})` took too much time to start. There might be an issue. Verify that `self.threads[thread_name].start()` has been called.")
return self.logs[ts_utc_name] return self.logs[ts_utc_name]
def __del__(self): def __del__(self):
for thread_name in self.queues:
# Send value that corresponds to `break` logic # Send value that corresponds to `break` logic
self.queues[thread_name].put(None) # if self.queue is not None:
self.queues[thread_name].join() # self.queue.put(None)
# self.queue.join()
for thread_name in self.queues: if self.thread is not None:
self.threads[thread_name].join() self.thread.join()
# TODO(rcadene): find a simple way to exit threads created by async_read
# def read(self, data_name, motor_name: str): # def read(self, data_name, motor_name: str):
# motor_idx, model = self.motors[motor_name] # motor_idx, model = self.motors[motor_name]

View File

@ -1,6 +1,7 @@
import pickle import pickle
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from pathlib import Path from pathlib import Path
import time
import numpy as np import numpy as np
import torch import torch
@ -243,11 +244,20 @@ class KochRobot:
self.leader_arms = self.config.leader_arms self.leader_arms = self.config.leader_arms
self.follower_arms = self.config.follower_arms self.follower_arms = self.config.follower_arms
self.cameras = self.config.cameras self.cameras = self.config.cameras
self.is_connected = False
self.logs = {}
self.async_read = False self.async_read = False
self.async_write = False self.async_write = False
def init_teleop(self): def connect(self):
if self.is_connected:
raise ValueError(f"KochRobot is already connected.")
for name in self.follower_arms:
self.follower_arms[name].connect()
self.leader_arms[name].connect()
if self.calibration_path.exists(): if self.calibration_path.exists():
# Reset all arms before setting calibration # Reset all arms before setting calibration
for name in self.follower_arms: for name in self.follower_arms:
@ -279,7 +289,12 @@ class KochRobot:
for name in self.cameras: for name in self.cameras:
self.cameras[name].connect() self.cameras[name].connect()
self.is_connected = True
def run_calibration(self): def run_calibration(self):
if not self.is_connected:
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
calibration = {} calibration = {}
for name in self.follower_arms: for name in self.follower_arms:
@ -301,13 +316,18 @@ class KochRobot:
def teleop_step( def teleop_step(
self, record_data=False self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
# Prepare to assign the positions of the leader to the follower # Prepare to assign the positions of the leader to the follower
leader_pos = {} leader_pos = {}
for name in self.leader_arms: for name in self.leader_arms:
now = time.perf_counter()
if self.async_read: if self.async_read:
leader_pos[name] = self.leader_arms[name].async_read("Present_Position") leader_pos[name] = self.leader_arms[name].async_read("Present_Position")
else: else:
leader_pos[name] = self.leader_arms[name].read("Present_Position") leader_pos[name] = self.leader_arms[name].read("Present_Position")
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - now
follower_goal_pos = {} follower_goal_pos = {}
for name in self.leader_arms: for name in self.leader_arms:
@ -315,10 +335,12 @@ class KochRobot:
# Send action # Send action
for name in self.follower_arms: for name in self.follower_arms:
now = time.perf_counter()
if self.async_write: if self.async_write:
self.follower_arms[name].async_write("Goal_Position", follower_goal_pos[name]) self.follower_arms[name].async_write("Goal_Position", follower_goal_pos[name])
else: else:
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name]) self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - now
# Early exit when recording data is not requested # Early exit when recording data is not requested
if not record_data: if not record_data:
@ -327,10 +349,12 @@ class KochRobot:
# Read follower position # Read follower position
follower_pos = {} follower_pos = {}
for name in self.follower_arms: for name in self.follower_arms:
now = time.perf_counter()
if self.async_read: if self.async_read:
follower_pos[name] = self.follower_arms[name].async_read("Present_Position") follower_pos[name] = self.follower_arms[name].async_read("Present_Position")
else: else:
follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = self.follower_arms[name].read("Present_Position")
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
# Create state by concatenating follower current position # Create state by concatenating follower current position
state = [] state = []
@ -349,7 +373,10 @@ class KochRobot:
# Capture images from cameras # Capture images from cameras
images = {} images = {}
for name in self.cameras: for name in self.cameras:
now = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
# Populate output dictionaries and format to pytorch # Populate output dictionaries and format to pytorch
obs_dict, action_dict = {}, {} obs_dict, action_dict = {}, {}
@ -361,6 +388,9 @@ class KochRobot:
return obs_dict, action_dict return obs_dict, action_dict
def capture_observation(self): def capture_observation(self):
if not self.is_connected:
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
# Read follower position # Read follower position
follower_pos = {} follower_pos = {}
for name in self.follower_arms: for name in self.follower_arms:
@ -389,6 +419,9 @@ class KochRobot:
return obs_dict return obs_dict
def send_action(self, action): def send_action(self, action):
if not self.is_connected:
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
from_idx = 0 from_idx = 0
to_idx = 0 to_idx = 0
follower_goal_pos = {} follower_goal_pos = {}

View File

@ -63,6 +63,7 @@ python lerobot/scripts/control_robot.py run_policy \
import argparse import argparse
import concurrent.futures import concurrent.futures
import logging
import os import os
import shutil import shutil
import time import time
@ -82,7 +83,7 @@ from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import save_meta_data from lerobot.scripts.push_dataset_to_hub import save_meta_data
@ -97,7 +98,6 @@ def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100) img.save(str(path), quality=100)
def busy_wait(seconds): def busy_wait(seconds):
# Significantly more accurate than `time.sleep`, and mandatory for our use case, # Significantly more accurate than `time.sleep`, and mandatory for our use case,
# but it consumes CPU cycles. # but it consumes CPU cycles.
@ -106,12 +106,41 @@ def busy_wait(seconds):
while time.perf_counter() < end_time: while time.perf_counter() < end_time:
pass pass
def none_or_int(value): def none_or_int(value):
if value == "None": if value == "None":
return None return None
return int(value) return int(value)
def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
    """Log one line summarising the timings of the latest control step.

    Every timing is rendered as `<label>:<milliseconds>=<frequency>hz` and all
    items are joined with single spaces before being sent to `logging.info`.

    Args:
        robot: Object exposing `leader_arms`, `follower_arms` and `cameras`
            (iterables of names) and a `logs` dict of durations in seconds.
        dt_s: Total duration of the control step, in seconds.
        episode_index: Optional episode number, logged as `ep:<n>`.
        frame_index: Optional frame number, logged as `frame:<n>`.

    Labels:
        dtRlead<x>  read of leader arm position
        dtWfoll<x>  write of follower arm goal position
        dtRfoll<x>  read of follower arm position
        dtRcam<x>   camera read (duration reported by the camera itself)
        dtARcam<x>  camera `async_read` call as seen by the caller
        where `<x>` is the first character of the arm/camera name.
    """

    def format_timing(label, value_s):
        # A zero duration would raise ZeroDivisionError; report an infinite
        # frequency instead of crashing the control loop for a log line.
        freq_hz = 1 / value_s if value_s != 0 else float("inf")
        return f"{label}:{value_s * 1000:5.2f}={freq_hz:3.1f}hz"

    def logged_timing(label, key):
        # Not every timing is recorded on every code path (e.g. the follower
        # read and camera timings are skipped when a teleop step exits early
        # because no data is recorded), so missing entries are simply omitted.
        value_s = robot.logs.get(key)
        return [] if value_s is None else [format_timing(label, value_s)]

    log_items = []
    if episode_index is not None:
        log_items.append(f"ep:{episode_index}")
    if frame_index is not None:
        log_items.append(f"frame:{frame_index}")

    # Total step time displayed in milliseconds and its frequency.
    log_items.append(format_timing("dt", dt_s))

    for name in robot.leader_arms:
        log_items += logged_timing(f"dtRlead{name[0]}", f"read_leader_{name}_pos_dt_s")

    for name in robot.follower_arms:
        # "W" is the goal-position write, "R" is the present-position read.
        log_items += logged_timing(f"dtWfoll{name[0]}", f"write_follower_{name}_goal_pos_dt_s")
        log_items += logged_timing(f"dtRfoll{name[0]}", f"read_follower_{name}_pos_dt_s")

    for name in robot.cameras:
        log_items += logged_timing(f"dtRcam{name[0]}", f"read_camera_{name}_dt_s")
        log_items += logged_timing(f"dtARcam{name[0]}", f"async_read_camera_{name}_dt_s")

    logging.info(" ".join(log_items))
######################################################################################## ########################################################################################
# Control modes # Control modes
@ -130,7 +159,7 @@ def teleoperate(robot: Robot, fps: int | None = None):
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}") log_control_info(robot, dt_s)
def record_dataset( def record_dataset(
@ -157,15 +186,18 @@ def record_dataset(
videos_dir.mkdir(parents=True, exist_ok=True) videos_dir.mkdir(parents=True, exist_ok=True)
# Save images using threads to reach high fps (30 and more) # Save images using threads to reach high fps (30 and more)
# Using `with` ensures the program exits smoothly if an exception is raised. # Using `with` to exit smoothly if an exception is raised.
with concurrent.futures.ThreadPoolExecutor() as executor: # Using only 4 worker threads to avoid blocking the main thread.
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
# Execute a few seconds without recording data, to give time
# to the robot devices to connect and start synchronizing.
timestamp = 0 timestamp = 0
start_time = time.perf_counter() start_time = time.perf_counter()
is_warmup_print = False is_warmup_print = False
while timestamp < warmup_time_s: while timestamp < warmup_time_s:
if not is_warmup_print: if not is_warmup_print:
print("Warming up by skipping frames") logging.info("Warming up (no data recording)")
os.system('say "Warmup" &') os.system('say "Warmup" &')
is_warmup_print = True is_warmup_print = True
@ -176,10 +208,11 @@ def record_dataset(
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)") log_control_info(robot, dt_s)
timestamp = time.perf_counter() - start_time timestamp = time.perf_counter() - start_time
# Start recording all episodes
ep_dicts = [] ep_dicts = []
for episode_index in range(num_episodes): for episode_index in range(num_episodes):
ep_dict = {} ep_dict = {}
@ -189,7 +222,7 @@ def record_dataset(
is_record_print = False is_record_print = False
while timestamp < episode_time_s: while timestamp < episode_time_s:
if not is_record_print: if not is_record_print:
print(f"Recording episode {episode_index}") logging.info(f"Recording episode {episode_index}")
os.system(f'say "Recording episode {episode_index}" &') os.system(f'say "Recording episode {episode_index}" &')
is_record_print = True is_record_print = True
@ -218,11 +251,11 @@ def record_dataset(
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}") log_control_info(robot, dt_s)
timestamp = time.perf_counter() - start_time timestamp = time.perf_counter() - start_time
print("Encoding images to videos") logging.info("Encoding images to videos")
num_frames = frame_index num_frames = frame_index
@ -232,6 +265,7 @@ def record_dataset(
video_path = local_dir / "videos" / fname video_path = local_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps) encode_video_frames(tmp_imgs_dir, video_path, fps)
# TODO(rcadene): uncomment?
# clean temporary images directory # clean temporary images directory
# shutil.rmtree(tmp_imgs_dir) # shutil.rmtree(tmp_imgs_dir)
@ -304,7 +338,7 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
robot.init_teleop() robot.init_teleop()
print("Replaying episode") logging.info("Replaying episode")
os.system('say "Replaying episode"') os.system('say "Replaying episode"')
for idx in range(from_idx, to_idx): for idx in range(from_idx, to_idx):
@ -317,7 +351,7 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}") log_control_info(robot, dt_s)
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig): def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
@ -349,7 +383,7 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}") log_control_info(robot, dt_s)
if __name__ == "__main__": if __name__ == "__main__":
@ -406,6 +440,8 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
init_logging()
control_mode = args.mode control_mode = args.mode
robot_name = args.robot robot_name = args.robot
kwargs = vars(args) kwargs = vars(args)