From dc8353a9ed569d16ec41405dc2fa1fbb65c63637 Mon Sep 17 00:00:00 2001 From: fan-ziqi Date: Sun, 13 Oct 2024 00:17:02 +0800 Subject: [PATCH] fix: simplified handling of historical observation parameters, ensuring initialization of the observation history buffer in non-empty cases. --- src/rl_sar/scripts/rl_sdk.py | 6 +----- src/rl_sar/scripts/rl_sim.py | 4 ++-- src/rl_sar/src/rl_sim.cpp | 4 ++-- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/rl_sar/scripts/rl_sdk.py b/src/rl_sar/scripts/rl_sdk.py index 6c2a642..2761b30 100644 --- a/src/rl_sar/scripts/rl_sdk.py +++ b/src/rl_sar/scripts/rl_sdk.py @@ -390,11 +390,7 @@ class RL: self.params.decimation = config["decimation"] self.params.num_observations = config["num_observations"] self.params.observations = config["observations"] - if config["observations_history"] is None: - self.params.observations_history = None - else: - self.params.observations_history = config["observations_history"] - + self.params.observations_history = config["observations_history"] self.params.clip_obs = config["clip_obs"] self.params.action_scale = config["action_scale"] self.params.hip_scale_reduction = config["hip_scale_reduction"] diff --git a/src/rl_sar/scripts/rl_sim.py b/src/rl_sar/scripts/rl_sim.py index bcc4192..fc88643 100644 --- a/src/rl_sar/scripts/rl_sim.py +++ b/src/rl_sar/scripts/rl_sim.py @@ -39,7 +39,7 @@ class RL_Sim(RL): self.params.observations[i] = "ang_vel_world" # history - if self.params.observations_history is None: + if len(self.params.observations_history) != 0: self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history)) # Due to the fact that the robot_state_publisher sorts the joint names alphabetically, @@ -205,7 +205,7 @@ class RL_Sim(RL): def Forward(self): torch.set_grad_enabled(False) clamped_obs = self.ComputeObservation() - if self.params.observations_history is None: + if len(self.params.observations_history) != 0: self.history_obs_buf.insert(clamped_obs) history_obs = self.history_obs_buf.get_obs_vec(self.params.observations_history) actions = self.model.forward(history_obs) diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 6242fdb..7f93df9 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -20,7 +20,7 @@ RL_Sim::RL_Sim() } // history - if (!this->params.observations_history.empty()) + if (this->params.observations_history.size() != 0) { this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size()); } @@ -259,7 +259,7 @@ torch::Tensor RL_Sim::Forward() torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor actions; - if (!this->params.observations_history.empty()) + if (this->params.observations_history.size() != 0) { this->history_obs_buf.insert(clamped_obs); this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history);