mirror of https://github.com/fan-ziqi/rl_sar.git
fix: simplified handling of historical observation parameters, ensuring initialization of the observation history buffer in non-empty cases.
This commit is contained in:
parent
8c397b3c7a
commit
dc8353a9ed
|
@ -390,11 +390,7 @@ class RL:
|
|||
self.params.decimation = config["decimation"]
|
||||
self.params.num_observations = config["num_observations"]
|
||||
self.params.observations = config["observations"]
|
||||
if config["observations_history"] is None:
|
||||
self.params.observations_history = None
|
||||
else:
|
||||
self.params.observations_history = config["observations_history"]
|
||||
|
||||
self.params.observations_history = config["observations_history"]
|
||||
self.params.clip_obs = config["clip_obs"]
|
||||
self.params.action_scale = config["action_scale"]
|
||||
self.params.hip_scale_reduction = config["hip_scale_reduction"]
|
||||
|
|
|
@ -39,7 +39,7 @@ class RL_Sim(RL):
|
|||
self.params.observations[i] = "ang_vel_world"
|
||||
|
||||
# history
|
||||
if self.params.observations_history is None:
|
||||
if len(self.params.observations_history) != 0:
|
||||
self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history))
|
||||
|
||||
# Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
|
||||
|
@ -205,7 +205,7 @@ class RL_Sim(RL):
|
|||
def Forward(self):
|
||||
torch.set_grad_enabled(False)
|
||||
clamped_obs = self.ComputeObservation()
|
||||
if self.params.observations_history is None:
|
||||
if len(self.params.observations_history) != 0:
|
||||
self.history_obs_buf.insert(clamped_obs)
|
||||
history_obs = self.history_obs_buf.get_obs_vec(self.params.observations_history)
|
||||
actions = self.model.forward(history_obs)
|
||||
|
|
|
@ -20,7 +20,7 @@ RL_Sim::RL_Sim()
|
|||
}
|
||||
|
||||
// history
|
||||
if (!this->params.observations_history.empty())
|
||||
if (this->params.observations_history.size() != 0)
|
||||
{
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
|
||||
}
|
||||
|
@ -259,7 +259,7 @@ torch::Tensor RL_Sim::Forward()
|
|||
torch::Tensor clamped_obs = this->ComputeObservation();
|
||||
|
||||
torch::Tensor actions;
|
||||
if (!this->params.observations_history.empty())
|
||||
if (this->params.observations_history.size() != 0)
|
||||
{
|
||||
this->history_obs_buf.insert(clamped_obs);
|
||||
this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history);
|
||||
|
|
Loading…
Reference in New Issue