solved issue between keyboard listener and renderer

This commit is contained in:
Michel Aractingi 2024-10-17 10:36:06 +02:00
parent d386f50045
commit b45490874a
1 changed files with 21 additions and 13 deletions

View File

@ -254,6 +254,8 @@ def create_rl_hf_dataset(data_dict):
def teleoperate(env, robot: Robot, teleop_time_s=None, **kwargs): def teleoperate(env, robot: Robot, teleop_time_s=None, **kwargs):
env = env()
axis_directions = kwargs.get('axis_directions', [1]) axis_directions = kwargs.get('axis_directions', [1])
offsets = kwargs.get('offsets', [0]) offsets = kwargs.get('offsets', [0])
fps = kwargs.get('fps', None) fps = kwargs.get('fps', None)
@ -287,6 +289,7 @@ def record(
tags=None, tags=None,
num_image_writers_per_camera=4, num_image_writers_per_camera=4,
force_override=False, force_override=False,
visualization_mode='viewer',
**kwargs **kwargs
): ):
@ -345,6 +348,9 @@ def record(
listener = keyboard.Listener(on_press=on_press) listener = keyboard.Listener(on_press=on_press)
listener.start() listener.start()
# create env
env = env()
# Save images using threads to reach high fps (30 and more) # Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised. # Using `with` to exist smoothly if an execption is raised.
futures = [] futures = []
@ -357,14 +363,14 @@ def record(
command_queue = multiprocessing.Queue(1000) command_queue = multiprocessing.Queue(1000)
stop_reading_leader = multiprocessing.Value('i', 0) stop_reading_leader = multiprocessing.Value('i', 0)
read_leader = multiprocessing.Process(target=read_commands_from_leader, args=(robot, command_queue, fps, axis_directions, offsets, stop_reading_leader)) read_leader = multiprocessing.Process(target=read_commands_from_leader, args=(robot, command_queue, fps, axis_directions, offsets, stop_reading_leader))
if not is_headless(): if not is_headless() and visualization_mode=='observations':
observations_queue = multiprocessing.Queue(1000) observations_queue = multiprocessing.Queue(1000)
show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, )) show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
show_images.start()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
# Start recording all episodes # Start recording all episodes
# start reading from leader, disable stop flag in leader process # start reading from leader, disable stop flag in leader process
if not is_headless(): show_images.start()
while episode_index < num_episodes: while episode_index < num_episodes:
logging.info(f"Recording episode {episode_index}") logging.info(f"Recording episode {episode_index}")
say(f"Recording episode {episode_index}") say(f"Recording episode {episode_index}")
@ -387,14 +393,8 @@ def record(
save_image, observation[key].squeeze(0), str_key, frame_index, episode_index, videos_dir) save_image, observation[key].squeeze(0), str_key, frame_index, episode_index, videos_dir)
] ]
if not is_headless(): if not is_headless() and visualization_mode=='observations':
observations_queue.put(observation) observations_queue.put(observation)
#executor.submit(show_image_observations, observation, image_keys)
#show_image_observations(observation, image_keys)
# for key in image_keys:
# #breakpoint()
# cv2.imshow(key, cv2.cvtColor(observation[key].squeeze(0), cv2.COLOR_RGB2BGR))
# cv2.waitKey(1)
state_obs = [] state_obs = []
for key in state_keys: for key in state_keys:
@ -494,7 +494,7 @@ def record(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images" concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
): ):
pass pass
if not is_headless(): if not is_headless() and visualization_mode=='rgb_array':
show_images.terminate() show_images.terminate()
observations_queue.close() observations_queue.close()
break break
@ -691,6 +691,13 @@ if __name__ == "__main__":
default=0, default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.", help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
) )
parser_record.add_argument(
"--visualization-mode",
type=str,
default='viewer',
choices=['viewer', 'observations'],
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument( parser_replay.add_argument(
@ -724,7 +731,8 @@ if __name__ == "__main__":
# make gym env # make gym env
env_cfg = init_hydra_config(env_config_path) env_cfg = init_hydra_config(env_config_path)
env = make_env(env_cfg, n_envs=1) env_cfg.env.gym.render_mode = 'human' if args.visualization_mode=='viewer' else 'rgb_array'
env_fn = lambda: make_env(env_cfg, n_envs=1)
# make robot # make robot
robot_overrides = ['~cameras', '~follower_arms'] robot_overrides = ['~cameras', '~follower_arms']
@ -734,10 +742,10 @@ if __name__ == "__main__":
kwargs.update(env_cfg.calibration) kwargs.update(env_cfg.calibration)
if control_mode == "teleoperate": if control_mode == "teleoperate":
teleoperate(env, robot, **kwargs) teleoperate(env_fn, robot, **kwargs)
elif control_mode == "record": elif control_mode == "record":
record(env, robot, **kwargs) record(env_fn, robot, **kwargs)
elif control_mode == "replay": elif control_mode == "replay":
replay(robot, **kwargs) replay(robot, **kwargs)