From 80e766c05ce5e83af239ec0e35c7892a07f7da1f Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 19 Mar 2025 18:37:50 +0000 Subject: [PATCH] Add intervention rate tracking in act_with_policy function - Introduced counters for tracking intervention steps and total steps during training. - Calculated and logged the intervention rate at the end of each episode. - Reset intervention counters after each episode to ensure accurate tracking. --- lerobot/scripts/server/actor_server.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index 45fd34a3..8264bf00 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -335,6 +335,9 @@ def act_with_policy( list_transition_to_send_to_learner = [] list_policy_time = [] episode_intervention = False + # Add counters for intervention rate calculation + episode_intervention_steps = 0 + episode_total_steps = 0 for interaction_step in range(cfg.training.online_steps): start_time = time.perf_counter() @@ -372,6 +375,8 @@ def act_with_policy( ) sum_reward_episode += float(reward) + # Increment total steps counter for intervention rate + episode_total_steps += 1 # NOTE: We overide the action if the intervention is True, because the action applied is the intervention action if "is_intervention" in info and info["is_intervention"]: @@ -380,6 +385,8 @@ def act_with_policy( # but sometimes for example we want to deactivate the gripper action = info["action_intervention"] episode_intervention = True + # Increment intervention steps counter + episode_intervention_steps += 1 # Check for NaN values in observations for key, tensor in obs.items(): @@ -424,6 +431,11 @@ def act_with_policy( stats = get_frequency_stats(list_policy_time) list_policy_time.clear() + # Calculate intervention rate + intervention_rate = 0.0 + if episode_total_steps > 0: + intervention_rate = episode_intervention_steps / episode_total_steps + # Send episodic reward to the learner interactions_queue.put( python_object_to_bytes( @@ -431,12 +443,16 @@ def act_with_policy( "Episodic reward": sum_reward_episode, "Interaction step": interaction_step, "Episode intervention": int(episode_intervention), + "Intervention rate": intervention_rate, **stats, } ) ) sum_reward_episode = 0.0 episode_intervention = False + # Reset intervention counters + episode_intervention_steps = 0 + episode_total_steps = 0 obs, info = online_env.reset() if cfg.fps is not None: