- Added a `JointMaskingActionSpace` wrapper in `gym_manipulator` to select which joints are controlled; for example, the gripper action can be disabled for some tasks (usage sketch after this list).

- Added NaN detection in the actor, learner, and `gym_manipulator` to catch NaNs that appear in the training loop (see the helper sketch after this list).
- Changed the `non_blocking` flag in the `.to(device)` calls to apply only on CUDA, because non-blocking transfers were producing NaNs when running the policy on MPS (pattern shown after this list).
- Added joint clipping and limits in the env, robot, and policy configs. TODO: consolidate the limits into a single config file.
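
For reference, a minimal sketch of the masking convention the new wrapper relies on (1 = keep the joint, 0 = drop it), using NumPy only; the 6-DoF layout and the action values are illustrative, not taken from the robot config:

```python
import numpy as np

# Mask from the example config below: keep the first four joints, drop wrist roll and gripper.
mask = np.array([1, 1, 1, 1, 0, 0], dtype=bool)
active_dims = np.where(mask)[0]

# Illustrative 6-DoF action; values are made up.
full_action = np.array([0.10, -0.20, 0.05, 0.30, 0.70, 0.90], dtype=np.float32)

# What the wrapped action space exposes to the policy: only the active joints.
masked_action = full_action[active_dims]  # shape (4,)

# What JointMaskingActionSpace.action() reconstructs before stepping the env:
# masked joints are filled with zeros.
reconstructed = np.zeros_like(full_action)
reconstructed[active_dims] = masked_action
```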
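
The NaN checks added below all follow the same log-and-continue pattern; a condensed sketch, assuming observations are a dict of tensors (the helper name here is illustrative):

```python
import logging

import torch


def log_nan_in_batch(observations: dict[str, torch.Tensor], actions: torch.Tensor) -> None:
    """Log (rather than raise) when NaNs show up, so one bad transition does not stop the loop."""
    for key, value in observations.items():
        if torch.isnan(value).any():
            logging.error(f"observations[{key}] contains NaN values")
    if torch.isnan(actions).any():
        logging.error("actions contains NaN values")


log_nan_in_batch({"observation.state": torch.tensor([float("nan"), 0.0])}, torch.zeros(2))
```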
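
The device-conditional transfer used throughout the diff boils down to enabling `non_blocking` only when the target is CUDA; a minimal sketch of that pattern:

```python
import torch


def to_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Asynchronous copies are only reliable on CUDA; with non_blocking=True on MPS
    # we observed NaNs, so the flag is tied to the device type.
    return tensor.to(device, non_blocking=device.type == "cuda")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state = to_device(torch.randn(1, 6), device)
```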

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Michel Aractingi 2025-02-11 11:34:46 +01:00
parent b5f89439ff
commit a7db3959f5
9 changed files with 161 additions and 31 deletions

View File

@@ -145,8 +145,8 @@ class Classifier(
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def predict_reward(self, x):
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:
return (self.forward(x).probabilities > 0.6).float()
return (self.forward(x).probabilities > threshold).float()
else:
return torch.argmax(self.forward(x).probabilities, dim=1)

View File

@@ -18,10 +18,12 @@ env:
control_time_s: 20
reset_follower_pos: true
use_relative_joint_positions: true
reset_time_s: 10
reset_time_s: 5
display_cameras: false
delta_action: 0.1
joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper
reward_classifier:
pretrained_path: outputs/classifier/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml

View File

@@ -20,7 +20,7 @@ training:
lr: 3e-4
eval_freq: 2500
log_freq: 500
log_freq: 1
save_freq: 2000000
online_steps: 1000000
@@ -31,7 +31,7 @@ training:
online_env_seed: 10000
online_buffer_capacity: 1000000
online_buffer_seed_size: 0
online_step_before_learning: 100 #5000
online_step_before_learning: 1000 #5000
do_online_rollout_async: false
policy_update_freq: 1
@@ -76,8 +76,10 @@ policy:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.state:
min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274]
max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
output_normalization_modes:
action: min_max

View File

@@ -15,8 +15,13 @@ calibration_dir: .cache/calibration/so100
# the number of motors in your follower arms.
max_relative_target: null
joint_position_relative_bounds:
min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,
0.53691274]
max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266,
0.60402685]
# min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
# max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
leader_arms:
main:

View File

@@ -101,10 +101,14 @@ class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
message = message_queue.get(block=True)
if message.transition is not None:
transition_to_send_to_learner = [
move_transition_to_device(T, device="cpu") for T in message.transition
transition_to_send_to_learner: list[Transition] = [
move_transition_to_device(transition=T, device="cpu") for T in message.transition
]
# Check for NaNs in transitions before sending to learner
for transition in transition_to_send_to_learner:
for key, value in transition["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
buf = io.BytesIO()
torch.save(transition_to_send_to_learner, buf)
transition_bytes = buf.getvalue()
@@ -226,7 +230,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
with TimerManager(
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
) as timer: # noqa: F841
action = policy.select_action(batch=obs) * 0.0
action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -238,7 +242,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = torch.from_numpy(action[0]).to(device, non_blocking=True).unsqueeze(dim=0)
action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
)
sum_reward_episode += float(reward)
@@ -247,6 +253,11 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
# TODO: Check the shape
action = info["action_intervention"]
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
list_transition_to_send_to_learner.append(
Transition(
state=obs,

View File

@@ -43,16 +43,27 @@ class BatchTransition(TypedDict):
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
# Move state tensors to CPU
transition["state"] = {key: val.to(device, non_blocking=True) for key, val in transition["state"].items()}
device = torch.device(device)
transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
}
# Move action to CPU
transition["action"] = transition["action"].to(device, non_blocking=True)
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
# No need to move reward or done, as they are float and bool
# No need to move reward or done, as they are float and bool
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
# Move next_state tensors to CPU
transition["next_state"] = {
key: val.to(device, non_blocking=True) for key, val in transition["next_state"].items()
key: val.to(device, non_blocking=device.type == "cuda")
for key, val in transition["next_state"].items()
}
# If complementary_info is present, move its tensors to CPU

View File

@@ -40,6 +40,7 @@ def find_joint_bounds(
min = np.min(np.stack(pos_list), 0)
print(f"Max angle position per joint {max}")
print(f"Min angle position per joint {min}")
break
if __name__ == "__main__":

View File

@@ -296,12 +296,17 @@ class RewardWrapper(gym.Wrapper):
# NOTE: We got 15% speedup by compiling the model
self.reward_classifier = torch.compile(reward_classifier)
if isinstance(device, str):
device = torch.device(device)
self.device = device
def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
if "image" in key
]
start_time = time.perf_counter()
with torch.inference_mode():
@@ -309,12 +314,76 @@
self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
)
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
if reward == 1.0:
terminated = True
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
return self.env.reset(seed=seed, options=options)
class JointMaskingActionSpace(gym.ActionWrapper):
def __init__(self, env, mask):
"""
Wrapper to mask out dimensions of the action space.
Args:
env: The environment to wrap
mask: Binary mask array where 0 indicates dimensions to remove
"""
super().__init__(env)
# Validate mask matches action space
# Keep only dimensions where mask is 1
self.active_dims = np.where(mask)[0]
if isinstance(env.action_space, gym.spaces.Box):
if len(mask) != env.action_space.shape[0]:
raise ValueError("Mask length must match action space dimensions")
low = env.action_space.low[self.active_dims]
high = env.action_space.high[self.active_dims]
self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
if isinstance(env.action_space, gym.spaces.Tuple):
if len(mask) != env.action_space[0].shape[0]:
raise ValueError("Mask length must match action space 0 dimensions")
low = env.action_space[0].low[self.active_dims]
high = env.action_space[0].high[self.active_dims]
action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
# Create new action space with masked dimensions
def action(self, action):
"""
Convert masked action back to full action space.
Args:
action: Action in masked space. For Tuple spaces, the first element is masked.
Returns:
Action in original space with masked dims set to 0.
"""
# Determine whether we are handling a Tuple space or a Box.
if isinstance(self.env.action_space, gym.spaces.Tuple):
# Extract the masked component from the tuple.
masked_action = action[0] if isinstance(action, tuple) else action
# Create a full action for the Box element.
full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
full_box_action[self.active_dims] = masked_action
# Return a tuple with the reconstructed Box action and the unchanged remainder.
return (full_box_action, action[1])
else:
# For Box action spaces.
masked_action = action if not isinstance(action, tuple) else action[0]
full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
full_action[self.active_dims] = masked_action
return full_action
class TimeLimitWrapper(gym.Wrapper):
def __init__(self, env, control_time_s, fps):
self.env = env
@@ -331,13 +400,12 @@ class TimeLimitWrapper(gym.Wrapper):
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
time_since_last_step = time.perf_counter() - self.last_timestamp
# logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
self.episode_time_in_s += time_since_last_step
self.last_timestamp = time.perf_counter()
self.current_step += 1
# check if last timestep took more time than the expected fps
# if 1.0 / time_since_last_step < self.fps:
# logging.warning(f"Current timestep exceeded expected fps {self.fps}")
if 1.0 / time_since_last_step < self.fps:
logging.debug(f"Current timestep exceeded expected fps {self.fps}")
if self.episode_time_in_s > self.control_time_s:
# if self.current_step >= self.max_episode_steps:
@@ -360,7 +428,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
print(f"obs_keys , {self.env.observation_space}")
print(f"crop params dict {crop_params_dict.keys()}")
for key_crop in crop_params_dict:
if key_crop not in self.env.observation_space.keys():
if key_crop not in self.env.observation_space.keys(): # noqa: SIM118
raise ValueError(f"Key {key_crop} not in observation space")
for key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
@@ -375,14 +443,23 @@
obs, reward, terminated, truncated, info = self.env.step(action)
for k in self.crop_params_dict:
device = obs[k].device
# Check for NaNs before processing
if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} before crop and resize")
if device == torch.device("mps:0"):
obs[k] = obs[k].cpu()
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
obs[k] = F.resize(obs[k], self.resize_size)
# Check for NaNs after processing
if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} after crop and resize")
obs[k] = obs[k].to(device)
# print(f"observation with key {k} with size {obs[k].size()}")
# cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
# cv2.waitKey(1)
return obs, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
@@ -400,12 +477,18 @@
class ConvertToLeRobotObservation(gym.ObservationWrapper):
def __init__(self, env, device):
super().__init__(env)
if isinstance(device, str):
device = torch.device(device)
self.device = device
def observation(self, observation):
observation = preprocess_observation(observation)
observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation}
observation = {
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
}
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
return observation
@@ -440,7 +523,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
if key == keyboard.Key.right or key == keyboard.Key.esc:
print("Right arrow key pressed. Exiting loop...")
self.events["exit_early"] = True
elif key == keyboard.Key.space:
elif key == keyboard.Key.space and not self.events["exit_early"]:
if not self.events["pause_policy"]:
print(
"Space key pressed. Human intervention required.\n"
@@ -587,6 +670,7 @@ def make_robot_env(
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
env = KeyboardInterfaceWrapper(env=env)
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
env = BatchCompitableWrapper(env=env)
return env
@@ -596,7 +680,7 @@ def make_robot_env(
def get_classifier(pretrained_path, config_path, device="mps"):
if pretrained_path is None or config_path is None:
return
return None
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig

View File

@@ -278,12 +278,23 @@ def learner_push_parameters(
torch.save(params_dict, buf)
params_bytes = buf.getvalue()
# Push them to the Actors "SendParameters" method
# Push them to the Actor's "SendParameters" method
logging.info("[LEARNER] Publishing parameters to the Actor")
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
time.sleep(seconds_between_pushes)
def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
for k in observations:
if torch.isnan(observations[k]).any():
logging.error(f"observations[{k}] contains NaN values")
for k in next_state:
if torch.isnan(next_state[k]).any():
logging.error(f"next_state[{k}] contains NaN values")
if torch.isnan(actions).any():
logging.error("actions contains NaN values")
def add_actor_information_and_train(
cfg,
device: str,
@@ -372,6 +383,7 @@ def add_actor_information_and_train(
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
@@ -399,6 +411,8 @@
next_observations = batch["next_state"]
done = batch["done"]
assert_and_breakpoint(observations=observations, actions=actions, next_state=next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -497,8 +511,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
It also initializes a learning rate scheduler, though currently, it is set to `None`.
**NOTE:**
- If the encoder is shared, its parameters are excluded from the actors optimization process.
- The policys log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.