Sign issue in modeling sac

This commit is contained in:
Michel Aractingi 2025-04-08 11:05:10 +02:00
parent 06de182448
commit 10adadbc71
3 changed files with 4 additions and 3 deletions

View File

@@ -202,7 +202,7 @@ class EnvWrapperConfig:
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
gripper_quantization_threshold: float | None = None
gripper_penalty: float = 0.0
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False

View File

@@ -79,7 +79,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(
val, device=device, non_blocking=non_blocking
val, device=device
)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")

View File

@@ -1195,7 +1195,8 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env = GripperActionWrapper(
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
)
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
if cfg.wrapper.gripper_penalty is not None:
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(