Add intervention rate tracking in act_with_policy function

- Introduced counters for tracking intervention steps and total steps during training.
- Calculated and logged the intervention rate at the end of each episode.
- Reset intervention counters after each episode to ensure accurate tracking.
This commit is contained in:
AdilZouitine 2025-03-19 18:37:50 +00:00
parent eb6787e159
commit 80e766c05c
1 changed files with 16 additions and 0 deletions

View File

@ -335,6 +335,9 @@ def act_with_policy(
list_transition_to_send_to_learner = []
list_policy_time = []
episode_intervention = False
# Add counters for intervention rate calculation
episode_intervention_steps = 0
episode_total_steps = 0
for interaction_step in range(cfg.training.online_steps):
start_time = time.perf_counter()
@ -372,6 +375,8 @@ def act_with_policy(
)
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
episode_total_steps += 1
# NOTE: We overide the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
@ -380,6 +385,8 @@ def act_with_policy(
# but sometimes for example we want to deactivate the gripper
action = info["action_intervention"]
episode_intervention = True
# Increment intervention steps counter
episode_intervention_steps += 1
# Check for NaN values in observations
for key, tensor in obs.items():
@ -424,6 +431,11 @@ def act_with_policy(
stats = get_frequency_stats(list_policy_time)
list_policy_time.clear()
# Calculate intervention rate
intervention_rate = 0.0
if episode_total_steps > 0:
intervention_rate = episode_intervention_steps / episode_total_steps
# Send episodic reward to the learner
interactions_queue.put(
python_object_to_bytes(
@ -431,12 +443,16 @@ def act_with_policy(
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
"Episode intervention": int(episode_intervention),
"Intervention rate": intervention_rate,
**stats,
}
)
)
sum_reward_episode = 0.0
episode_intervention = False
# Reset intervention counters
episode_intervention_steps = 0
episode_total_steps = 0
obs, info = online_env.reset()
if cfg.fps is not None: