Refactor -> control_loop()
This commit is contained in:
parent
20f5ac31fa
commit
904aaa497c
|
@ -38,7 +38,7 @@ def safe_stop_image_writer(func):
|
|||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
image_writer = kwargs["dataset"].get("image_writer")
|
||||
image_writer = kwargs.get("dataset", {}).get("image_writer")
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
|
|
|
@ -19,7 +19,7 @@ from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_
|
|||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, log_say, set_global_seed
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
|
||||
|
||||
|
@ -184,45 +184,28 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps):
|
|||
return policy, fps, device, use_amp
|
||||
|
||||
|
||||
def warmup_record(robot, events, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps):
|
||||
# TODO(rcadene): refactor warmup_record and reset_environment
|
||||
timestamp = 0
|
||||
start_warmup_t = time.perf_counter()
|
||||
|
||||
if warmup_time_s > 0:
|
||||
log_say("Warming up (no data recording)", play_sounds)
|
||||
|
||||
while timestamp < warmup_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if enable_teloperation:
|
||||
observation, _ = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_warmup_t
|
||||
if events is not None and events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def record_episode(
|
||||
dataset,
|
||||
def warmup_record(
|
||||
robot,
|
||||
events,
|
||||
enable_teloperation,
|
||||
warmup_time_s,
|
||||
display_cameras,
|
||||
fps,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=warmup_time_s,
|
||||
display_cameras=display_cameras,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=enable_teloperation,
|
||||
)
|
||||
|
||||
|
||||
def record_episode(
|
||||
robot,
|
||||
dataset,
|
||||
events,
|
||||
episode_time_s,
|
||||
display_cameras,
|
||||
policy,
|
||||
|
@ -230,24 +213,65 @@ def record_episode(
|
|||
use_amp,
|
||||
fps,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=episode_time_s,
|
||||
display_cameras=display_cameras,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
policy=policy,
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
teleoperate=policy is None,
|
||||
)
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def control_loop(
|
||||
robot,
|
||||
control_time_s,
|
||||
teleoperate=False,
|
||||
display_cameras=False,
|
||||
dataset=None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device=None,
|
||||
use_amp=None,
|
||||
fps=None,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
if events is None:
|
||||
events = {}
|
||||
|
||||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset["fps"] != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < episode_time_s:
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if policy is None:
|
||||
if teleoperate:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
pred_action = predict_action(observation, policy, device, use_amp)
|
||||
if policy is not None:
|
||||
pred_action = predict_action(observation, policy, device, use_amp)
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
action = robot.send_action(pred_action)
|
||||
action = {"action": action}
|
||||
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
action = robot.send_action(pred_action)
|
||||
action = {"action": action}
|
||||
|
||||
add_frame(dataset, observation, action)
|
||||
if dataset is not None:
|
||||
add_frame(dataset, observation, action)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
|
@ -262,7 +286,7 @@ def record_episode(
|
|||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if events is not None and events["exit_early"]:
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
|
@ -282,7 +306,7 @@ def reset_environment(robot, events, reset_time_s):
|
|||
time.sleep(1)
|
||||
timestamp = time.perf_counter() - start_vencod_t
|
||||
pbar.update(1)
|
||||
if events is not None and events["exit_early"]:
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
|
@ -111,7 +111,8 @@ from lerobot.common.datasets.populate_dataset import (
|
|||
init_dataset,
|
||||
save_current_episode,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_robot import (
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
has_method,
|
||||
init_keyboard_listener,
|
||||
init_policy,
|
||||
|
@ -177,25 +178,16 @@ def calibrate(robot: Robot, arms: list[str] | None):
|
|||
|
||||
|
||||
@safe_disconnect
|
||||
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_teleop_t = time.perf_counter()
|
||||
while True:
|
||||
start_loop_t = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
|
||||
break
|
||||
def teleoperate(
|
||||
robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
|
||||
):
|
||||
control_loop(
|
||||
robot,
|
||||
control_time_s=teleop_time_s,
|
||||
fps=fps,
|
||||
teleoperate=True,
|
||||
display_cameras=display_cameras,
|
||||
)
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
|
@ -254,7 +246,8 @@ def record(
|
|||
# 2. give times to the robot devices to connect and start synchronizing,
|
||||
# 3. place the cameras windows on screen
|
||||
enable_teleoperation = policy is None
|
||||
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps)
|
||||
log_say("Warmup record", play_sounds)
|
||||
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
|
||||
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
@ -277,32 +270,21 @@ def record(
|
|||
fps=fps,
|
||||
)
|
||||
|
||||
# In case stop recording is requested during `record_episode`
|
||||
if events is not None and events["stop_recording"]:
|
||||
save_current_episode(dataset)
|
||||
break
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
# Current code logic doesn't allow to teleoperate during this time.
|
||||
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||
# Skip reset for the last episode to be recorded
|
||||
if episode_index < num_episodes - 1:
|
||||
if not events["stop_recording"] and (
|
||||
(episode_index < num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
|
||||
# In case stop recording is requested during `reset_environment`
|
||||
if events is not None and events["stop_recording"]:
|
||||
save_current_episode(dataset)
|
||||
break
|
||||
|
||||
if events is not None and events["rerecord_episode"]:
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
delete_current_episode(dataset)
|
||||
# Force reset
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
continue
|
||||
|
||||
# Increment by one dataset["current_episode_index"]
|
||||
|
@ -320,6 +302,7 @@ def record(
|
|||
def replay(
|
||||
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
|
||||
):
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
|
@ -378,6 +361,12 @@ if __name__ == "__main__":
|
|||
parser_teleop.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_teleop.add_argument(
|
||||
"--display-cameras",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Display all cameras on screen (set to 1 to display or 0).",
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
parser_record.add_argument(
|
||||
|
|
Loading…
Reference in New Issue