Sign issue in modeling sac

This commit is contained in:
Michel Aractingi 2025-04-08 11:05:10 +02:00
parent 06de182448
commit 10adadbc71
3 changed files with 4 additions and 3 deletions

View File

@ -79,7 +79,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)): elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor( transition["complementary_info"][key] = torch.tensor(
val, device=device, non_blocking=non_blocking val, device=device
) )
else: else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")

View File

@ -1195,6 +1195,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env = GripperActionWrapper( env = GripperActionWrapper(
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
) )
if cfg.wrapper.gripper_penalty is not None:
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward) env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
if cfg.wrapper.ee_action_space_params is not None: if cfg.wrapper.ee_action_space_params is not None: