Add log_control_info

This commit is contained in:
Remi Cadene 2024-07-06 17:54:22 +02:00
parent 3ff789c181
commit d83d34d9b3
4 changed files with 157 additions and 66 deletions

View File

@ -6,6 +6,10 @@ from pathlib import Path
from threading import Thread
import cv2
# Using 1 thread to avoid blocking the main thread.
# Especially useful during data collection when other threads are used
# to save the images.
cv2.setNumThreads(1)
import numpy as np
from lerobot.common.robot_devices.cameras.utils import save_color_image
@ -120,7 +124,6 @@ class OpenCVCamera:
self.camera = None
self.is_connected = False
self.threads = {}
self.results = {}
self.logs = {}

View File

@ -154,14 +154,19 @@ def get_group_sync_key(data_name, motor_names):
group_key = f"{data_name}_" + "_".join(motor_names)
return group_key
def get_thread_name(fn_name, data_name, motor_names):
def get_result_name(fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
thread_name = f"{fn_name}_{group_key}"
return thread_name
rslt_name = f"{fn_name}_{group_key}"
return rslt_name
def get_queue_name(fn_name, data_name, motor_names):
    """Return the queue identifier for one function / data / motor-group combination.

    The name is the function name joined by an underscore to the group key
    produced by `get_group_sync_key(data_name, motor_names)`.
    """
    return f"{fn_name}_{get_group_sync_key(data_name, motor_names)}"
def get_log_name(var_name, fn_name, data_name, motor_names):
    """Return the key under which a logged variable is stored.

    The key has the form `{var_name}_{fn_name}_{group_key}`, where `group_key`
    comes from `get_group_sync_key(data_name, motor_names)`.
    """
    # The name is built directly from the group key; the earlier detour through
    # a thread name produced the exact same string and was overwritten anyway.
    group_key = get_group_sync_key(data_name, motor_names)
    return f"{var_name}_{fn_name}_{group_key}"
class TorqueMode(enum.Enum):
@ -197,6 +202,18 @@ class DynamixelMotorsBus:
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
self.port_handler = None
self.packet_handler = None
self.calibration = None
self.is_connected = False
self.group_readers = {}
self.group_writers = {}
self.logs = {}
def connect(self):
if self.is_connected:
raise ValueError(f"KochRobot is already connected.")
self.port_handler = PortHandler(self.port)
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
@ -206,15 +223,13 @@ class DynamixelMotorsBus:
self.port_handler.setBaudRate(BAUD_RATE)
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
self.group_readers = {}
self.group_writers = {}
self.calibration = None
self.threads = {}
self.queues = {}
# for async_read and async_write
self.thread = None
self.async_read_args = {}
self.write_queue = Queue()
self.results = {}
self.logs = {}
self.is_connected = True
@property
def motor_names(self) -> list[int]:
@ -258,6 +273,9 @@ class DynamixelMotorsBus:
return values
def read(self, data_name, motor_names: list[str] | None = None):
if not self.is_connected:
raise ValueError(f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`.")
start_time = time.perf_counter()
if motor_names is None:
@ -320,6 +338,9 @@ class DynamixelMotorsBus:
return values
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise ValueError(f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`.")
start_time = time.perf_counter()
if motor_names is None:
@ -406,10 +427,20 @@ class DynamixelMotorsBus:
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
self.logs[ts_utc_name] = capture_timestamp_utc()
def read_loop(self, data_name, motor_names: list[str] | None = None):
def read_write_loop(self, async_read_args, write_queue):
while True:
thread_name = get_thread_name("read", data_name, motor_names)
self.results[thread_name] = self.read(data_name, motor_names)
for group_name, read_args in async_read_args.items():
self.results[group_name] = self.read(*read_args)
if write_queue.empty():
continue
write_args = write_queue.get()
if write_args is None: # A way to terminate the thread
break
self.write(*write_args)
write_queue.task_done()
def async_read(self, data_name, motor_names: list[str] | None = None):
if motor_names is None:
@ -418,32 +449,25 @@ class DynamixelMotorsBus:
if isinstance(motor_names, str):
motor_names = [motor_names]
thread_name = get_thread_name("read", data_name, motor_names)
if self.thread is None:
self.thread = Thread(target=self.read_write_loop, args=(self.async_read_args, self.write_queue))
self.thread.daemon = True
self.thread.start()
if thread_name not in self.threads:
self.threads[thread_name] = Thread(target=self.read_loop, args=(data_name, motor_names))
self.threads[thread_name].daemon = True
self.threads[thread_name].start()
group_name = get_group_sync_key(data_name, motor_names)
self.async_read_args[group_name] = (data_name, motor_names)
FPS = 200
num_tries = 0
while thread_name not in self.results:
while group_name not in self.results:
num_tries += 1
time.sleep(1 / FPS)
if num_tries > FPS:
if self.threads[thread_name].ident is None and not self.threads[thread_name].is_alive():
if self.thread.ident is None and not self.thread.is_alive():
raise Exception(f"The thread responsible for `self.async_read({data_name}, {motor_names})` took too much time to start. There might be an issue. Verify that `self.threads[thread_name].start()` has been called.")
# ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
return self.results[thread_name] #, self.logs[ts_utc_name]
def write_loop(self, data_name, queue: Queue, motor_names: list[str] | None = None):
while True:
values = queue.get()
if values is None: # A way to terminate the thread
break
self.write(data_name, values, motor_names)
queue.task_done()
return self.results[group_name] #, self.logs[ts_utc_name]
def async_write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if motor_names is None:
@ -457,38 +481,33 @@ class DynamixelMotorsBus:
values = np.array(values)
thread_name = get_thread_name("write", data_name, motor_names)
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
if self.thread is None:
self.thread = Thread(target=self.read_write_loop, args=(self.async_read_args, self.write_queue))
self.thread.daemon = True
self.thread.start()
if thread_name not in self.threads:
self.queues[thread_name] = Queue()
self.threads[thread_name] = Thread(target=self.write_loop, args=(data_name, self.queues[thread_name], motor_names))
self.threads[thread_name].daemon = True
self.threads[thread_name].start()
self.queues[thread_name].put(values)
self.write_queue.put((data_name, values, motor_names))
FPS = 200
num_tries = 0
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
while ts_utc_name not in self.logs:
num_tries += 1
time.sleep(1 / FPS)
if num_tries > FPS:
if self.threads[thread_name].ident is None and not self.threads[thread_name].is_alive():
if self.thread.ident is None and not self.thread.is_alive():
raise Exception(f"The thread responsible for `self.async_write({data_name}, {values}, {motor_names})` took too much time to start. There might be an issue. Verify that `self.threads[thread_name].start()` has been called.")
return self.logs[ts_utc_name]
def __del__(self):
for thread_name in self.queues:
# Send value that corresponds to `break` logic
self.queues[thread_name].put(None)
self.queues[thread_name].join()
# Send value that corresponds to `break` logic
# if self.queue is not None:
# self.queue.put(None)
# self.queue.join()
for thread_name in self.queues:
self.threads[thread_name].join()
# TODO(rcadene): find a simple way to exit threads created by async_read
if self.thread is not None:
self.thread.join()
# def read(self, data_name, motor_name: str):
# motor_idx, model = self.motors[motor_name]

View File

@ -1,6 +1,7 @@
import pickle
from dataclasses import dataclass, field, replace
from pathlib import Path
import time
import numpy as np
import torch
@ -243,11 +244,20 @@ class KochRobot:
self.leader_arms = self.config.leader_arms
self.follower_arms = self.config.follower_arms
self.cameras = self.config.cameras
self.is_connected = False
self.logs = {}
self.async_read = False
self.async_write = False
def init_teleop(self):
def connect(self):
if self.is_connected:
raise ValueError(f"KochRobot is already connected.")
for name in self.follower_arms:
self.follower_arms[name].connect()
self.leader_arms[name].connect()
if self.calibration_path.exists():
# Reset all arms before setting calibration
for name in self.follower_arms:
@ -279,7 +289,12 @@ class KochRobot:
for name in self.cameras:
self.cameras[name].connect()
self.is_connected = True
def run_calibration(self):
if not self.is_connected:
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
calibration = {}
for name in self.follower_arms:
@ -301,13 +316,18 @@ class KochRobot:
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
# Prepare to assign the positions of the leader to the follower
leader_pos = {}
for name in self.leader_arms:
now = time.perf_counter()
if self.async_read:
leader_pos[name] = self.leader_arms[name].async_read("Present_Position")
else:
leader_pos[name] = self.leader_arms[name].read("Present_Position")
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - now
follower_goal_pos = {}
for name in self.leader_arms:
@ -315,10 +335,12 @@ class KochRobot:
# Send action
for name in self.follower_arms:
now = time.perf_counter()
if self.async_write:
self.follower_arms[name].async_write("Goal_Position", follower_goal_pos[name])
else:
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - now
# Early exit when recording data is not requested
if not record_data:
@ -327,10 +349,12 @@ class KochRobot:
# Read follower position
follower_pos = {}
for name in self.follower_arms:
now = time.perf_counter()
if self.async_read:
follower_pos[name] = self.follower_arms[name].async_read("Present_Position")
else:
follower_pos[name] = self.follower_arms[name].read("Present_Position")
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
# Create state by concatenating follower current position
state = []
@ -349,7 +373,10 @@ class KochRobot:
# Capture images from cameras
images = {}
for name in self.cameras:
now = time.perf_counter()
images[name] = self.cameras[name].async_read()
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
# Populate output dictionaries and format to pytorch
obs_dict, action_dict = {}, {}
@ -361,6 +388,9 @@ class KochRobot:
return obs_dict, action_dict
def capture_observation(self):
if not self.is_connected:
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
# Read follower position
follower_pos = {}
for name in self.follower_arms:
@ -389,6 +419,9 @@ class KochRobot:
return obs_dict
def send_action(self, action):
if not self.is_connected:
raise ValueError(f"KochRobot is not connected. You need to run `robot.connect()`.")
from_idx = 0
to_idx = 0
follower_goal_pos = {}

View File

@ -63,6 +63,7 @@ python lerobot/scripts/control_robot.py run_policy \
import argparse
import concurrent.futures
import logging
import os
import shutil
import time
@ -82,7 +83,7 @@ from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import save_meta_data
@ -97,7 +98,6 @@ def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def busy_wait(seconds):
# Significantly more accurate than `time.sleep`, and mandatory for our use case,
# but it consumes CPU cycles.
@ -106,12 +106,41 @@ def busy_wait(seconds):
while time.perf_counter() < end_time:
pass
def none_or_int(value):
    """Convert `value` to an int, mapping the literal string "None" to `None`."""
    return None if value == "None" else int(value)
def _format_dt(label, dt_s):
    """Format one timing entry as `<label>:<milliseconds>=<frequency>hz`."""
    # A duration of exactly zero has no finite frequency; report `inf` rather
    # than raising ZeroDivisionError in the middle of a control loop.
    freq_hz = 1 / dt_s if dt_s != 0 else float("inf")
    return f"{label}:{dt_s * 1000:5.2f}={freq_hz:3.1f}hz"


def log_control_info(robot, dt_s, episode_index=None, frame_index=None):
    """Log a one-line summary of the timings of the last control step.

    Args:
        robot: Object exposing `leader_arms`, `follower_arms` and `cameras`
            (iterables of names) and a `logs` dict of durations in seconds.
        dt_s: Total duration of the control step, in seconds.
        episode_index: Optional episode number, logged as `ep:<n>`.
        frame_index: Optional frame number, logged as `frame:<n>`.

    Each duration is shown in milliseconds together with its frequency.
    Timings missing from `robot.logs` are skipped, since not every control
    step reads every device (e.g. a step that does not record data).
    """
    log_items = []
    if episode_index is not None:
        log_items.append(f"ep:{episode_index}")
    if frame_index is not None:
        log_items.append(f"frame:{frame_index}")

    # Total step time displayed in milliseconds and its frequency.
    log_items.append(_format_dt("dt", dt_s))

    def add_if_logged(label, key):
        # Only report timings that were actually measured.
        if key in robot.logs:
            log_items.append(_format_dt(label, robot.logs[key]))

    for name in robot.leader_arms:
        add_if_logged(f"dtRlead{name[0]}", f"read_leader_{name}_pos_dt_s")

    for name in robot.follower_arms:
        # "W" labels the goal-position write, "R" the present-position read.
        add_if_logged(f"dtWfoll{name[0]}", f"write_follower_{name}_goal_pos_dt_s")
        add_if_logged(f"dtRfoll{name[0]}", f"read_follower_{name}_pos_dt_s")

    for name in robot.cameras:
        add_if_logged(f"dtRcam{name[0]}", f"read_camera_{name}_dt_s")
        add_if_logged(f"dtARcam{name[0]}", f"async_read_camera_{name}_dt_s")

    logging.info(" ".join(log_items))
########################################################################################
# Control modes
@ -130,7 +159,7 @@ def teleoperate(robot: Robot, fps: int | None = None):
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
log_control_info(robot, dt_s)
def record_dataset(
@ -157,15 +186,18 @@ def record_dataset(
videos_dir.mkdir(parents=True, exist_ok=True)
# Save images using threads to reach high fps (30 and more)
# Using `with` ensures the program exists smoothly if an execption is raised.
with concurrent.futures.ThreadPoolExecutor() as executor:
# Using `with` to exit smoothly if an exception is raised.
# Using only 4 worker threads to avoid blocking the main thread.
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
# Run for a few seconds without recording data, to give the robot
# devices time to connect and start synchronizing.
timestamp = 0
start_time = time.perf_counter()
is_warmup_print = False
while timestamp < warmup_time_s:
if not is_warmup_print:
print("Warming up by skipping frames")
logging.info("Warming up (no data recording)")
os.system('say "Warmup" &')
is_warmup_print = True
@ -176,10 +208,11 @@ def record_dataset(
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
log_control_info(robot, dt_s)
timestamp = time.perf_counter() - start_time
# Start recording all episodes
ep_dicts = []
for episode_index in range(num_episodes):
ep_dict = {}
@ -189,7 +222,7 @@ def record_dataset(
is_record_print = False
while timestamp < episode_time_s:
if not is_record_print:
print(f"Recording episode {episode_index}")
logging.info(f"Recording episode {episode_index}")
os.system(f'say "Recording episode {episode_index}" &')
is_record_print = True
@ -218,11 +251,11 @@ def record_dataset(
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
log_control_info(robot, dt_s)
timestamp = time.perf_counter() - start_time
print("Encoding images to videos")
logging.info("Encoding images to videos")
num_frames = frame_index
@ -232,6 +265,7 @@ def record_dataset(
video_path = local_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# TODO(rcadene): uncomment?
# clean temporary images directory
# shutil.rmtree(tmp_imgs_dir)
@ -304,7 +338,7 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
robot.init_teleop()
print("Replaying episode")
logging.info("Replaying episode")
os.system('say "Replaying episode"')
for idx in range(from_idx, to_idx):
@ -317,7 +351,7 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
log_control_info(robot, dt_s)
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
@ -349,7 +383,7 @@ def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
log_control_info(robot, dt_s)
if __name__ == "__main__":
@ -406,6 +440,8 @@ if __name__ == "__main__":
)
args = parser.parse_args()
init_logging()
control_mode = args.mode
robot_name = args.robot
kwargs = vars(args)