fix: simplified handling of historical observation parameters, ensuring initialization of the observation history buffer in non-empty cases.

This commit is contained in:
fan-ziqi 2024-10-13 00:17:02 +08:00
parent 8c397b3c7a
commit dc8353a9ed
3 changed files with 5 additions and 9 deletions

View File

@ -390,11 +390,7 @@ class RL:
self.params.decimation = config["decimation"] self.params.decimation = config["decimation"]
self.params.num_observations = config["num_observations"] self.params.num_observations = config["num_observations"]
self.params.observations = config["observations"] self.params.observations = config["observations"]
if config["observations_history"] is None:
self.params.observations_history = None
else:
self.params.observations_history = config["observations_history"] self.params.observations_history = config["observations_history"]
self.params.clip_obs = config["clip_obs"] self.params.clip_obs = config["clip_obs"]
self.params.action_scale = config["action_scale"] self.params.action_scale = config["action_scale"]
self.params.hip_scale_reduction = config["hip_scale_reduction"] self.params.hip_scale_reduction = config["hip_scale_reduction"]

View File

@ -39,7 +39,7 @@ class RL_Sim(RL):
self.params.observations[i] = "ang_vel_world" self.params.observations[i] = "ang_vel_world"
# history # history
if self.params.observations_history is None: if len(self.params.observations_history) != 0:
self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history)) self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history))
# Due to the fact that the robot_state_publisher sorts the joint names alphabetically, # Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
@ -205,7 +205,7 @@ class RL_Sim(RL):
def Forward(self): def Forward(self):
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
clamped_obs = self.ComputeObservation() clamped_obs = self.ComputeObservation()
if self.params.observations_history is None: if len(self.params.observations_history) != 0:
self.history_obs_buf.insert(clamped_obs) self.history_obs_buf.insert(clamped_obs)
history_obs = self.history_obs_buf.get_obs_vec(self.params.observations_history) history_obs = self.history_obs_buf.get_obs_vec(self.params.observations_history)
actions = self.model.forward(history_obs) actions = self.model.forward(history_obs)

View File

@ -20,7 +20,7 @@ RL_Sim::RL_Sim()
} }
// history // history
if (!this->params.observations_history.empty()) if (this->params.observations_history.size() != 0)
{ {
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size()); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
} }
@ -259,7 +259,7 @@ torch::Tensor RL_Sim::Forward()
torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions; torch::Tensor actions;
if (!this->params.observations_history.empty()) if (this->params.observations_history.size() != 0)
{ {
this->history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history); this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history);