diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 85d44c54..e198a814 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -180,6 +180,12 @@ def is_headless(): print() return True +def init_read_leader(robot, fps, **kwargs): + axis_directions = kwargs.get('axis_directions', [1]) + offsets = kwargs.get('offsets', [0]) + command_queue = multiprocessing.Queue(1000) + read_leader = multiprocessing.Process(target=read_commands_from_leader, args=(robot, command_queue, fps, axis_directions, offsets)) + return read_leader, command_queue def read_commands_from_leader(robot: Robot, queue: multiprocessing.Queue, fps: int, axis_directions: list, offsets: list, stop_flag=None): if not robot.is_connected: @@ -255,15 +261,9 @@ def create_rl_hf_dataset(data_dict): def teleoperate(env, robot: Robot, teleop_time_s=None, **kwargs): env = env() - - axis_directions = kwargs.get('axis_directions', [1]) - offsets = kwargs.get('offsets', [0]) - fps = kwargs.get('fps', None) - command_queue = multiprocessing.Queue(1000) - read_leader = multiprocessing.Process(target=read_commands_from_leader, args=(robot, command_queue, fps, axis_directions, offsets)) - env.reset() - + + read_leader, command_queue = init_read_leader(robot, **kwargs) start_teleop_t = time.perf_counter() read_leader.start() while True: @@ -356,12 +356,7 @@ def record( num_image_writers = num_image_writers_per_camera * 2 ############### num_image_writers = max(num_image_writers, 1) - # Parameters for the control - axis_directions = kwargs.get('axis_directions', [1]) - offsets = kwargs.get('offsets', [0]) - command_queue = multiprocessing.Queue(1000) - stop_reading_leader = multiprocessing.Value('i', 0) - read_leader = multiprocessing.Process(target=read_commands_from_leader, args=(robot, command_queue, fps, axis_directions, offsets, stop_reading_leader)) + read_leader, command_queue = init_read_leader(robot, fps, **kwargs) if not is_headless() and visualize_images: observations_queue = multiprocessing.Queue(1000) show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, )) @@ -411,7 +406,6 @@ def record( timestamp = time.perf_counter() - start_episode_t if exit_early: - # If the episode is successful then break exit_early = False break @@ -422,9 +416,7 @@ def record( # stop_reading_leader is blocking command_queue.close() read_leader.terminate() - - command_queue = multiprocessing.Queue(1000) - read_leader = multiprocessing.Process(target=read_commands_from_leader, args=(robot, command_queue, fps, axis_directions, offsets, stop_reading_leader)) + read_leader, command_queue = init_read_leader(robot, fps, **kwargs) timestamp = 0 @@ -745,7 +737,7 @@ if __name__ == "__main__": record(env_fn, robot, **kwargs) elif control_mode == "replay": - replay(robot, **kwargs) + replay(env_fn, robot, **kwargs) else: raise ValueError(f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay." )