- Added a `JointMaskingActionSpace` wrapper in `gym_manipulator` to select which joints are controlled; for example, gripper actions can be disabled for some tasks.
- Added NaN detection in the actor, the learner, and `gym_manipulator` to catch NaNs that appear in the control/training loop.
- Changed `non_blocking` in the `.to(device)` calls so it is only enabled for CUDA devices, because non-blocking transfers were producing NaNs when running the policy on MPS (see the sketch below).
- Added joint clipping and limits in the env, robot, and policy configs. TODO: consolidate these limits into a single config file.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
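As an illustration of the device-aware transfer described above (an editor's sketch, not code from this commit), the pattern boils down to requesting `non_blocking` only when the target is a CUDA device, so MPS and CPU copies stay synchronous. The helper name and the dummy observation dict below are hypothetical:

```python
import torch


def move_to_device(tensors: dict[str, torch.Tensor], device: str | torch.device) -> dict[str, torch.Tensor]:
    """Move a dict of tensors, enabling non_blocking copies only on CUDA (MPS/CPU stay synchronous)."""
    device = torch.device(device)
    non_blocking = device.type == "cuda"
    return {key: val.to(device, non_blocking=non_blocking) for key, val in tensors.items()}


# Hypothetical observation dict, just to show the call:
obs = {"observation.state": torch.zeros(1, 6), "observation.image": torch.zeros(1, 3, 128, 128)}
obs = move_to_device(obs, "cuda" if torch.cuda.is_available() else "cpu")
```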
commit a7db3959f5 (parent b5f89439ff)
@@ -145,8 +145,8 @@ class Classifier(
         return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

-    def predict_reward(self, x):
+    def predict_reward(self, x, threshold=0.6):
         if self.config.num_classes == 2:
-            return (self.forward(x).probabilities > 0.6).float()
+            return (self.forward(x).probabilities > threshold).float()
         else:
             return torch.argmax(self.forward(x).probabilities, dim=1)
@@ -18,10 +18,12 @@ env:
   control_time_s: 20
   reset_follower_pos: true
   use_relative_joint_positions: true
-  reset_time_s: 10
+  reset_time_s: 5
   display_cameras: false
   delta_action: 0.1
+  joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper

   reward_classifier:
     pretrained_path: outputs/classifier/checkpoints/best/pretrained_model
     config_path: lerobot/configs/policy/hilserl_classifier.yaml
@@ -20,7 +20,7 @@ training:
   lr: 3e-4

   eval_freq: 2500
-  log_freq: 500
+  log_freq: 1
   save_freq: 2000000

   online_steps: 1000000
@@ -31,7 +31,7 @@ training:
   online_env_seed: 10000
   online_buffer_capacity: 1000000
   online_buffer_seed_size: 0
-  online_step_before_learning: 100 #5000
+  online_step_before_learning: 1000 #5000
   do_online_rollout_async: false
   policy_update_freq: 1
@@ -76,8 +76,10 @@ policy:
       mean: [0.485, 0.456, 0.406]
       std: [0.229, 0.224, 0.225]
     observation.state:
-      min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
-      max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
+      min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375, 0.53691274]
+      max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
+      # min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
+      # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]

   output_normalization_modes:
     action: min_max
@@ -15,8 +15,13 @@ calibration_dir: .cache/calibration/so100
 # the number of motors in your follower arms.
 max_relative_target: null
 joint_position_relative_bounds:
-  min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
-  max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]
+  min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,
+    0.53691274]
+  max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266,
+    0.60402685]
+  # min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
+  # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]

 leader_arms:
   main:
@@ -101,10 +101,14 @@ class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
             message = message_queue.get(block=True)

             if message.transition is not None:
-                transition_to_send_to_learner = [
-                    move_transition_to_device(T, device="cpu") for T in message.transition
+                transition_to_send_to_learner: list[Transition] = [
+                    move_transition_to_device(transition=T, device="cpu") for T in message.transition
                 ]
+                # Check for NaNs in transitions before sending to learner
+                for transition in transition_to_send_to_learner:
+                    for key, value in transition["state"].items():
+                        if torch.isnan(value).any():
+                            logging.warning(f"Found NaN values in transition {key}")
                 buf = io.BytesIO()
                 torch.save(transition_to_send_to_learner, buf)
                 transition_bytes = buf.getvalue()
@@ -226,7 +230,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
             with TimerManager(
                 elapsed_time_list=list_policy_time, label="Policy inference time", log=False
             ) as timer:  # noqa: F841
-                action = policy.select_action(batch=obs) * 0.0
+                action = policy.select_action(batch=obs)
             policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)

             log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -238,7 +242,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
         next_obs, reward, done, truncated, info = online_env.step(action)

         # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
-        action = torch.from_numpy(action[0]).to(device, non_blocking=True).unsqueeze(dim=0)
+        action = (
+            torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
+        )

         sum_reward_episode += float(reward)
@@ -247,6 +253,11 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
                 # TODO: Check the shape
                 action = info["action_intervention"]

+            # Check for NaN values in observations
+            for key, tensor in obs.items():
+                if torch.isnan(tensor).any():
+                    logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
+
             list_transition_to_send_to_learner.append(
                 Transition(
                     state=obs,
@@ -43,16 +43,27 @@ class BatchTransition(TypedDict):

 def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
     # Move state tensors to CPU
-    transition["state"] = {key: val.to(device, non_blocking=True) for key, val in transition["state"].items()}
+    device = torch.device(device)
+    transition["state"] = {
+        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+    }

     # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=True)
+    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")

     # No need to move reward or done, as they are float and bool
+
+    # No need to move reward or done, as they are float and bool
+    if isinstance(transition["reward"], torch.Tensor):
+        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
+
+    if isinstance(transition["done"], torch.Tensor):
+        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")

     # Move next_state tensors to CPU
     transition["next_state"] = {
-        key: val.to(device, non_blocking=True) for key, val in transition["next_state"].items()
+        key: val.to(device, non_blocking=device.type == "cuda")
+        for key, val in transition["next_state"].items()
     }

     # If complementary_info is present, move its tensors to CPU
@@ -40,6 +40,7 @@ def find_joint_bounds(
         min = np.min(np.stack(pos_list), 0)
         print(f"Max angle position per joint {max}")
         print(f"Min angle position per joint {min}")
+        break


 if __name__ == "__main__":
@@ -296,12 +296,17 @@ class RewardWrapper(gym.Wrapper):

         # NOTE: We got 15% speedup by compiling the model
         self.reward_classifier = torch.compile(reward_classifier)

+        if isinstance(device, str):
+            device = torch.device(device)
         self.device = device

     def step(self, action):
         observation, _, terminated, truncated, info = self.env.step(action)
         images = [
-            observation[key].to(self.device, non_blocking=True) for key in observation if "image" in key
+            observation[key].to(self.device, non_blocking=self.device.type == "cuda")
+            for key in observation
+            if "image" in key
         ]
         start_time = time.perf_counter()
         with torch.inference_mode():
@@ -309,12 +314,76 @@ class RewardWrapper(gym.Wrapper):
                 self.reward_classifier.predict_reward(images) if self.reward_classifier is not None else 0.0
             )
         info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)

+        if reward == 1.0:
+            terminated = True
         return observation, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
         return self.env.reset(seed=seed, options=options)


+class JointMaskingActionSpace(gym.ActionWrapper):
+    def __init__(self, env, mask):
+        """
+        Wrapper to mask out dimensions of the action space.
+
+        Args:
+            env: The environment to wrap
+            mask: Binary mask array where 0 indicates dimensions to remove
+        """
+        super().__init__(env)
+
+        # Validate mask matches action space
+
+        # Keep only dimensions where mask is 1
+        self.active_dims = np.where(mask)[0]
+
+        if isinstance(env.action_space, gym.spaces.Box):
+            if len(mask) != env.action_space.shape[0]:
+                raise ValueError("Mask length must match action space dimensions")
+            low = env.action_space.low[self.active_dims]
+            high = env.action_space.high[self.active_dims]
+            self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
+
+        if isinstance(env.action_space, gym.spaces.Tuple):
+            if len(mask) != env.action_space[0].shape[0]:
+                raise ValueError("Mask length must match action space 0 dimensions")
+
+            low = env.action_space[0].low[self.active_dims]
+            high = env.action_space[0].high[self.active_dims]
+            action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
+            self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
+            # Create new action space with masked dimensions
+
+    def action(self, action):
+        """
+        Convert masked action back to full action space.
+
+        Args:
+            action: Action in masked space. For Tuple spaces, the first element is masked.
+
+        Returns:
+            Action in original space with masked dims set to 0.
+        """
+
+        # Determine whether we are handling a Tuple space or a Box.
+        if isinstance(self.env.action_space, gym.spaces.Tuple):
+            # Extract the masked component from the tuple.
+            masked_action = action[0] if isinstance(action, tuple) else action
+            # Create a full action for the Box element.
+            full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
+            full_box_action[self.active_dims] = masked_action
+            # Return a tuple with the reconstructed Box action and the unchanged remainder.
+            return (full_box_action, action[1])
+        else:
+            # For Box action spaces.
+            masked_action = action if not isinstance(action, tuple) else action[0]
+            full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
+            full_action[self.active_dims] = masked_action
+            return full_action
+
+
 class TimeLimitWrapper(gym.Wrapper):
     def __init__(self, env, control_time_s, fps):
         self.env = env
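To make the masking logic above easier to review, here is a small standalone check (an editor's sketch, not part of the diff) of how a masked space is built and how a reduced action is expanded back to the full vector; the six-dimensional Box and the `[1, 1, 1, 1, 0, 0]` mask mirror the env config value:

```python
import gymnasium as gym
import numpy as np

# Full 6-DoF action space; the mask keeps the first four joints and zeros out wrist and gripper.
full_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)
mask = [1, 1, 1, 1, 0, 0]
active_dims = np.where(mask)[0]

# Masked action space, built the same way as in the wrapper's __init__.
masked_space = gym.spaces.Box(low=full_space.low[active_dims], high=full_space.high[active_dims], dtype=full_space.dtype)

# An action in the reduced space is written back into a zero-filled full action.
masked_action = np.array([0.1, -0.2, 0.3, 0.05], dtype=np.float32)
full_action = np.zeros(full_space.shape, dtype=full_space.dtype)
full_action[active_dims] = masked_action
print(full_action)  # -> [0.1, -0.2, 0.3, 0.05, 0.0, 0.0]
```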
@@ -331,13 +400,12 @@ class TimeLimitWrapper(gym.Wrapper):
     def step(self, action):
         obs, reward, terminated, truncated, info = self.env.step(action)
         time_since_last_step = time.perf_counter() - self.last_timestamp
-        # logging.warning(f"Current timestep is lower than the expected fps {self.fps}")
         self.episode_time_in_s += time_since_last_step
         self.last_timestamp = time.perf_counter()
         self.current_step += 1
         # check if last timestep took more time than the expected fps
-        # if 1.0 / time_since_last_step < self.fps:
-        #     logging.warning(f"Current timestep exceeded expected fps {self.fps}")
+        if 1.0 / time_since_last_step < self.fps:
+            logging.debug(f"Current timestep exceeded expected fps {self.fps}")

         if self.episode_time_in_s > self.control_time_s:
             # if self.current_step >= self.max_episode_steps:
@@ -360,7 +428,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
         print(f"obs_keys , {self.env.observation_space}")
         print(f"crop params dict {crop_params_dict.keys()}")
         for key_crop in crop_params_dict:
-            if key_crop not in self.env.observation_space.keys():
+            if key_crop not in self.env.observation_space.keys():  # noqa: SIM118
                 raise ValueError(f"Key {key_crop} not in observation space")
         for key in crop_params_dict:
             top, left, height, width = crop_params_dict[key]
@@ -375,14 +443,23 @@ class ImageCropResizeWrapper(gym.Wrapper):
         obs, reward, terminated, truncated, info = self.env.step(action)
         for k in self.crop_params_dict:
             device = obs[k].device

+            # Check for NaNs before processing
+            if torch.isnan(obs[k]).any():
+                logging.error(f"NaN values detected in observation {k} before crop and resize")
+
             if device == torch.device("mps:0"):
                 obs[k] = obs[k].cpu()

             obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
             obs[k] = F.resize(obs[k], self.resize_size)
+
+            # Check for NaNs after processing
+            if torch.isnan(obs[k]).any():
+                logging.error(f"NaN values detected in observation {k} after crop and resize")
+
             obs[k] = obs[k].to(device)
-            # print(f"observation with key {k} with size {obs[k].size()}")
-            # cv2.imshow(k, cv2.cvtColor(obs[k].cpu().squeeze(0).permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))
-            # cv2.waitKey(1)
         return obs, reward, terminated, truncated, info

     def reset(self, seed=None, options=None):
@@ -400,12 +477,18 @@ class ImageCropResizeWrapper(gym.Wrapper):
 class ConvertToLeRobotObservation(gym.ObservationWrapper):
     def __init__(self, env, device):
         super().__init__(env)

+        if isinstance(device, str):
+            device = torch.device(device)
         self.device = device

     def observation(self, observation):
         observation = preprocess_observation(observation)

-        observation = {key: observation[key].to(self.device, non_blocking=True) for key in observation}
+        observation = {
+            key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
+            for key in observation
+        }
         observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
         return observation
@@ -440,7 +523,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
                 if key == keyboard.Key.right or key == keyboard.Key.esc:
                     print("Right arrow key pressed. Exiting loop...")
                     self.events["exit_early"] = True
-                elif key == keyboard.Key.space:
+                elif key == keyboard.Key.space and not self.events["exit_early"]:
                     if not self.events["pause_policy"]:
                         print(
                             "Space key pressed. Human intervention required.\n"
@@ -587,6 +670,7 @@ def make_robot_env(
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     env = KeyboardInterfaceWrapper(env=env)
     env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.wrapper.reset_time_s)
+    env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)

     return env
@@ -596,7 +680,7 @@ def make_robot_env(

 def get_classifier(pretrained_path, config_path, device="mps"):
     if pretrained_path is None or config_path is None:
-        return
+        return None

     from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
     from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
@@ -278,12 +278,23 @@ def learner_push_parameters(
     torch.save(params_dict, buf)
     params_bytes = buf.getvalue()

-    # Push them to the Actor’s "SendParameters" method
+    # Push them to the Actor's "SendParameters" method
     logging.info("[LEARNER] Publishing parameters to the Actor")
     response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))  # noqa: F841
     time.sleep(seconds_between_pushes)


+def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
+    for k in observations:
+        if torch.isnan(observations[k]).any():
+            logging.error(f"observations[{k}] contains NaN values")
+    for k in next_state:
+        if torch.isnan(next_state[k]).any():
+            logging.error(f"next_state[{k}] contains NaN values")
+    if torch.isnan(actions).any():
+        logging.error("actions contains NaN values")
+
+
 def add_actor_information_and_train(
     cfg,
     device: str,
@@ -372,6 +383,7 @@ def add_actor_information_and_train(
         observations = batch["state"]
         next_observations = batch["next_state"]
         done = batch["done"]
+        check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

         with policy_lock:
             loss_critic = policy.compute_loss_critic(
@@ -399,6 +411,8 @@ def add_actor_information_and_train(
         next_observations = batch["next_state"]
         done = batch["done"]

+        assert_and_breakpoint(observations=observations, actions=actions, next_state=next_observations)
+
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -497,8 +511,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     It also initializes a learning rate scheduler, though currently, it is set to `None`.

     **NOTE:**
-    - If the encoder is shared, its parameters are excluded from the actor’s optimization process.
-    - The policy’s log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
+    - If the encoder is shared, its parameters are excluded from the actor's optimization process.
+    - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.

     Args:
         cfg: Configuration object containing hyperparameters.