Refactor -> control_loop()

This commit is contained in:
Remi Cadene 2024-10-14 15:15:28 +02:00
parent 20f5ac31fa
commit 904aaa497c
3 changed files with 99 additions and 86 deletions

View File

@ -38,7 +38,7 @@ def safe_stop_image_writer(func):
try:
return func(*args, **kwargs)
except Exception as e:
image_writer = kwargs["dataset"].get("image_writer")
image_writer = kwargs.get("dataset", {}).get("image_writer")
if image_writer is not None:
print("Waiting for image writer to terminate...")
stop_image_writer(image_writer, timeout=20)

View File

@ -19,7 +19,7 @@ from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, log_say, set_global_seed
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
@ -184,45 +184,28 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
return policy, fps, device, use_amp
def warmup_record(robot, events, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps):
# TODO(rcadene): refactor warmup_record and reset_environment
timestamp = 0
start_warmup_t = time.perf_counter()
if warmup_time_s > 0:
log_say("Warming up (no data recording)", play_sounds)
while timestamp < warmup_time_s:
start_loop_t = time.perf_counter()
if enable_teloperation:
observation, _ = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_warmup_t
if events is not None and events["exit_early"]:
events["exit_early"] = False
break
@safe_stop_image_writer
def record_episode(
dataset,
def warmup_record(
robot,
events,
enable_teloperation,
warmup_time_s,
display_cameras,
fps,
):
control_loop(
robot=robot,
control_time_s=warmup_time_s,
display_cameras=display_cameras,
events=events,
fps=fps,
teleoperate=enable_teloperation,
)
def record_episode(
robot,
dataset,
events,
episode_time_s,
display_cameras,
policy,
@ -230,24 +213,65 @@ def record_episode(
use_amp,
fps,
):
control_loop(
robot=robot,
control_time_s=episode_time_s,
display_cameras=display_cameras,
dataset=dataset,
events=events,
policy=policy,
device=device,
use_amp=use_amp,
fps=fps,
teleoperate=policy is None,
)
@safe_stop_image_writer
def control_loop(
robot,
control_time_s,
teleoperate=False,
display_cameras=False,
dataset=None,
events=None,
policy=None,
device=None,
use_amp=None,
fps=None,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
if events is None:
events = {}
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and fps is not None and dataset["fps"] != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < episode_time_s:
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
if policy is None:
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
pred_action = predict_action(observation, policy, device, use_amp)
if policy is not None:
pred_action = predict_action(observation, policy, device, use_amp)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = robot.send_action(pred_action)
action = {"action": action}
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = robot.send_action(pred_action)
action = {"action": action}
add_frame(dataset, observation, action)
if dataset is not None:
add_frame(dataset, observation, action)
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
@ -262,7 +286,7 @@ def record_episode(
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_episode_t
if events is not None and events["exit_early"]:
if events["exit_early"]:
events["exit_early"] = False
break
@ -282,7 +306,7 @@ def reset_environment(robot, events, reset_time_s):
time.sleep(1)
timestamp = time.perf_counter() - start_vencod_t
pbar.update(1)
if events is not None and events["exit_early"]:
if events["exit_early"]:
events["exit_early"] = False
break

View File

@ -111,7 +111,8 @@ from lerobot.common.datasets.populate_dataset import (
init_dataset,
save_current_episode,
)
from lerobot.common.robot_devices.control_robot import (
from lerobot.common.robot_devices.control_utils import (
control_loop,
has_method,
init_keyboard_listener,
init_policy,
@ -177,25 +178,16 @@ def calibrate(robot: Robot, arms: list[str] | None):
@safe_disconnect
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
start_teleop_t = time.perf_counter()
while True:
start_loop_t = time.perf_counter()
robot.teleop_step()
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
break
def teleoperate(
robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
):
control_loop(
robot,
control_time_s=teleop_time_s,
fps=fps,
teleoperate=True,
display_cameras=display_cameras,
)
@safe_disconnect
@ -254,7 +246,8 @@ def record(
# 2. give times to the robot devices to connect and start synchronizing,
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps)
log_say("Warmup record", play_sounds)
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
@ -277,32 +270,21 @@ def record(
fps=fps,
)
# In case stop recording is requested during `record_episode`
if events is not None and events["stop_recording"]:
save_current_episode(dataset)
break
# Execute a few seconds without recording to give time to manually reset the environment
# Current code logic doesn't allow to teleoperate during this time.
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if episode_index < num_episodes - 1:
if not events["stop_recording"] and (
(episode_index < num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)
# In case stop recording is requested during `reset_environment`
if events is not None and events["stop_recording"]:
save_current_episode(dataset)
break
if events is not None and events["rerecord_episode"]:
if events["rerecord_episode"]:
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
delete_current_episode(dataset)
# Force reset
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)
continue
# Increment by one dataset["current_episode_index"]
@ -320,6 +302,7 @@ def record(
def replay(
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
@ -378,6 +361,12 @@ if __name__ == "__main__":
parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_teleop.add_argument(
"--display-cameras",
type=int,
default=1,
help="Display all cameras on screen (set to 1 to display or 0).",
)
parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument(