Sign issue in modeling sac
This commit is contained in:
parent
06de182448
commit
10adadbc71
|
@ -79,7 +79,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
||||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||||
elif isinstance(val, (int, float, bool)):
|
elif isinstance(val, (int, float, bool)):
|
||||||
transition["complementary_info"][key] = torch.tensor(
|
transition["complementary_info"][key] = torch.tensor(
|
||||||
val, device=device, non_blocking=non_blocking
|
val, device=device
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||||
|
|
|
@ -1195,6 +1195,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||||
env = GripperActionWrapper(
|
env = GripperActionWrapper(
|
||||||
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
||||||
)
|
)
|
||||||
|
if cfg.wrapper.gripper_penalty is not None:
|
||||||
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
|
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
|
||||||
|
|
||||||
if cfg.wrapper.ee_action_space_params is not None:
|
if cfg.wrapper.ee_action_space_params is not None:
|
||||||
|
|
Loading…
Reference in New Issue