Added logging for interventions to monitor the rate of interventions over time.
Added an "s" keyboard command to force success in case the reward classifier fails.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
parent b9217b06db
commit dc086dc21f
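For context, a minimal standalone sketch of the forced-success mechanism this commit adds. The event name and the "s"-key check match the diff below; the listener wiring here is an assumption, using pynput's keyboard.Listener callback API (special keys come in as keyboard.Key constants, character keys as objects with a .char attribute):

from threading import Lock

from pynput import keyboard  # assumed dependency; the diff uses its Key/.char API

events = {"episode_success": False}
event_lock = Lock()  # thread-safe access, as in KeyboardInterfaceWrapper

def on_press(key):
    # "s" is a character key, so it has no keyboard.Key constant;
    # check its .char attribute instead (mirrors the diff below).
    if hasattr(key, "char") and key.char == "s":
        with event_lock:
            events["episode_success"] = True

listener = keyboard.Listener(on_press=on_press)
listener.start()

# Later, in the wrapper's step(), the event overrides the classifier:
# with event_lock:
#     if events["episode_success"]:
#         reward = 1
#         terminated = True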
@@ -12,10 +12,10 @@ env:
   wrapper:
     crop_params_dict:
-      observation.images.front: [126, 43, 329, 518]
-      observation.images.side: [93, 69, 381, 434]
-      # observation.images.front: [135, 59, 331, 527]
-      # observation.images.side: [79, 47, 397, 450]
+      observation.images.front: [102, 43, 358, 523]
+      observation.images.side: [92, 123, 379, 349]
+      # observation.images.front: [109, 37, 361, 557]
+      # observation.images.side: [94, 161, 372, 315]
     resize_size: [128, 128]
     control_time_s: 20
     reset_follower_pos: true
 
@@ -4,8 +4,9 @@ defaults:
   - _self_
 
 seed: 13
-dataset_repo_id: aractingi/push_cube_square_reward_cropped_resized
-dataset_root: data/aractingi/push_cube_square_reward_cropped_resized
+dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
+# aractingi/push_cube_square_reward_1_cropped_resized
+dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized
 local_files_only: true
 train_split_proportion: 0.8
 
@@ -26,7 +27,6 @@ training:
   eval_freq: 1 # How often to run validation (in epochs)
   save_freq: 1 # How often to save checkpoints (in epochs)
   save_checkpoint: true
-  # image_keys: ["observation.images.top", "observation.images.wrist"]
   image_keys: ["observation.images.front", "observation.images.side"]
   label_key: "next.reward"
   profile_inference_time: false
@@ -37,8 +37,8 @@ eval:
   num_samples_to_log: 30 # Number of validation samples to log in the table
 
 policy:
-  name: "hilserl/classifier/push_cube_square_reward_cropped_resized" #"hilserl/classifier/pick_place_lego_cube_120
-  model_name: "helper2424/resnet10" # "facebook/convnext-base-224" #"helper2424/resnet10"
+  name: "hilserl/classifier"
+  model_name: "helper2424/resnet10" # "facebook/convnext-base-224
   model_type: "cnn"
   num_cameras: 2 # Has to be len(training.image_keys)
 
@@ -50,4 +50,4 @@ wandb:
 
 device: "mps"
 resume: false
-output_dir: "outputs/classifier/resnet10_frozen"
+output_dir: "outputs/classifier/old_trainer_resnet10_frozen"
@@ -223,6 +223,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
     sum_reward_episode = 0
     list_transition_to_send_to_learner = []
     list_policy_time = []
+    episode_intervention = False
 
     for interaction_step in range(cfg.training.online_steps):
         if interaction_step >= cfg.training.online_step_before_learning:
@@ -252,6 +253,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
             if info["is_intervention"]:
                 # TODO: Check the shape
                 action = info["action_intervention"]
+                episode_intervention = True
 
             # Check for NaN values in observations
             for key, tensor in obs.items():
@@ -295,11 +297,13 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
                     interaction_message={
                         "Episodic reward": sum_reward_episode,
                         "Interaction step": interaction_step,
+                        "Episode intervention": int(episode_intervention),
                         **stats,
                     }
                 )
             )
             sum_reward_episode = 0.0
+            episode_intervention = False
             obs, info = online_env.reset()
 
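The "Episode intervention" flag logged above is a per-episode 0/1 value; the running rate the commit message mentions falls out of averaging it. A hypothetical downstream sketch (the flag values are made up; intervention_rate is not a function in this codebase):

def intervention_rate(flags: list[int], window: int = 100) -> float:
    """Fraction of recent episodes that required a human intervention."""
    recent = flags[-window:] if flags else [0]
    return sum(recent) / len(recent)

# Example with made-up per-episode flags, as logged in interaction_message:
episode_flags = [0, 1, 1, 0, 1]
print(f"intervention rate: {intervention_rate(episode_flags):.0%}")  # 60%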
@@ -245,10 +245,18 @@ if __name__ == "__main__":
     images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
     images = {k: (v * 255).astype("uint8") for k, v in images.items()}
 
-    rois = select_square_roi_for_images(images)
+    # rois = select_square_roi_for_images(images)
+    rois = {
+        "observation.images.front": [102, 43, 358, 523],
+        "observation.images.side": [92, 123, 379, 349],
+    }
     # rois = {
-    #     "observation.images.front": [126, 43, 329, 518],
-    #     "observation.images.side": [93, 69, 381, 434],
+    #     "observation.images.side": (92, 123, 379, 349),
+    #     "observation.images.front": (109, 37, 361, 557),
+    # }
+    # rois = {
+    #     "observation.images.front": [109, 37, 361, 557],
+    #     "observation.images.side": [94, 161, 372, 315],
     # }
 
     # Print the selected rectangular ROIs
@@ -312,7 +312,7 @@ class RewardWrapper(gym.Wrapper):
         start_time = time.perf_counter()
         with torch.inference_mode():
             reward = (
-                self.reward_classifier.predict_reward(images, threshold=0.5)
+                self.reward_classifier.predict_reward(images, threshold=0.6)
                 if self.reward_classifier is not None
                 else 0.0
             )
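Raising the threshold from 0.5 to 0.6 makes the classifier more conservative: fewer false-positive rewards, at the cost of more missed successes, which is exactly the failure mode the new "s" key covers. A sketch of the thresholding step, assuming predict_reward reduces to comparing a success probability against the threshold (its internals are not shown in this diff, so this is an illustration, not the actual implementation):

import torch

def predict_reward_sketch(success_prob: torch.Tensor, threshold: float = 0.6) -> float:
    # Reward is 1.0 only when the classifier's success probability
    # clears the threshold; otherwise 0.0.
    return 1.0 if success_prob.item() > threshold else 0.0

print(predict_reward_sketch(torch.tensor(0.55)))  # 0.0 at threshold 0.6 (was 1.0 at 0.5)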
@@ -507,6 +507,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
             "pause_policy": False,
             "reset_env": False,
             "human_intervention_step": False,
+            "episode_success": False,
         }
         self.event_lock = Lock()  # Thread-safe access to events
         self._init_keyboard_listener()
|
||||||
if key == keyboard.Key.right or key == keyboard.Key.esc:
|
if key == keyboard.Key.right or key == keyboard.Key.esc:
|
||||||
print("Right arrow key pressed. Exiting loop...")
|
print("Right arrow key pressed. Exiting loop...")
|
||||||
self.events["exit_early"] = True
|
self.events["exit_early"] = True
|
||||||
elif key == keyboard.Key.space and not self.events["exit_early"]:
|
return
|
||||||
|
if hasattr(key, "char") and key.char == "s":
|
||||||
|
print("Key 's' pressed. Episode success triggered.")
|
||||||
|
self.events["episode_success"] = True
|
||||||
|
return
|
||||||
|
if key == keyboard.Key.space and not self.events["exit_early"]:
|
||||||
if not self.events["pause_policy"]:
|
if not self.events["pause_policy"]:
|
||||||
print(
|
print(
|
||||||
"Space key pressed. Human intervention required.\n"
|
"Space key pressed. Human intervention required.\n"
|
||||||
|
@@ -536,15 +542,18 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
                         )
                         self.events["pause_policy"] = True
                         log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
-                    elif self.events["pause_policy"] and not self.events["human_intervention_step"]:
+                        return
+                    if self.events["pause_policy"] and not self.events["human_intervention_step"]:
                         self.events["human_intervention_step"] = True
                         print("Space key pressed. Human intervention starting.")
                         log_say("Starting human intervention.", play_sounds=True)
-                    else:
+                        return
+                    if self.events["pause_policy"] and self.events["human_intervention_step"]:
                         self.events["pause_policy"] = False
                         self.events["human_intervention_step"] = False
                         print("Space key pressed for a third time.")
                         log_say("Continuing with policy actions.", play_sounds=True)
+                        return
             except Exception as e:
                 print(f"Error handling key press: {e}")
 
@@ -566,7 +575,6 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
         with self.event_lock:
             if self.events["exit_early"]:
                 terminated_by_keyboard = True
-            # If we need to wait for human intervention, we note that outside the lock.
             pause_policy = self.events["pause_policy"]
 
         if pause_policy:
@@ -580,6 +588,13 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
 
         # Execute the step in the underlying environment
         obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
+
+        # Override reward and termination if episode success event triggered
+        with self.event_lock:
+            if self.events["episode_success"]:
+                reward = 1
+                terminated_by_keyboard = True
+
         return obs, reward, terminated or terminated_by_keyboard, truncated, info
 
     def reset(self, **kwargs) -> Tuple[Any, Dict]: