Changed the init_final value to center the starting mean and std of the policy
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
parent
fcb9d3008b
commit
caf55994eb
|
@ -148,7 +148,7 @@ class Classifier(
|
|||
def predict_reward(self, x, threshold=0.6):
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.forward(x).probabilities
|
||||
logging.info(f"Predicted reward images: {probs}")
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.forward(x).probabilities, dim=1)
|
||||
|
|
|
@ -95,5 +95,6 @@ class SACConfig:
|
|||
"use_tanh_squash": True,
|
||||
"log_std_min": -5,
|
||||
"log_std_max": 2,
|
||||
"init_final": 0.01,
|
||||
}
|
||||
)
|
||||
|
|
|
@ -327,7 +327,7 @@ def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int
|
|||
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
stats = {}
|
||||
list_policy_fps = [1.0 / t for t in list_policy_time]
|
||||
if len(list_policy_fps) > 0:
|
||||
if len(list_policy_fps) > 1:
|
||||
policy_fps = mean(list_policy_fps)
|
||||
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
|
||||
|
|
|
@ -217,7 +217,7 @@ class HILSerlRobotEnv(gym.Env):
|
|||
if torch.any(teleop_action < -self.delta_relative_bounds_size * self.delta) and torch.any(
|
||||
teleop_action > self.delta_relative_bounds_size
|
||||
):
|
||||
print(
|
||||
logging.debug(
|
||||
f"Relative teleop delta exceeded bounds {self.delta_relative_bounds_size}, teleop_action {teleop_action}\n"
|
||||
f"lower bounds condition {teleop_action < -self.delta_relative_bounds_size}\n"
|
||||
f"upper bounds condition {teleop_action > self.delta_relative_bounds_size}"
|
||||
|
@ -318,7 +318,7 @@ class RewardWrapper(gym.Wrapper):
|
|||
)
|
||||
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
logging.info(f"Reward: {reward}")
|
||||
# logging.info(f"Reward: {reward}")
|
||||
|
||||
if reward == 1.0:
|
||||
terminated = True
|
||||
|
|
Loading…
Reference in New Issue