mirror of https://github.com/fan-ziqi/rl_sar.git
fix: simplify handling of the observations_history parameter by dropping the None special-case, and initialize the observation history buffer whenever the configured history list is non-empty.
This commit is contained in:
parent
8c397b3c7a
commit
dc8353a9ed
|
@ -390,11 +390,7 @@ class RL:
|
||||||
self.params.decimation = config["decimation"]
|
self.params.decimation = config["decimation"]
|
||||||
self.params.num_observations = config["num_observations"]
|
self.params.num_observations = config["num_observations"]
|
||||||
self.params.observations = config["observations"]
|
self.params.observations = config["observations"]
|
||||||
if config["observations_history"] is None:
|
|
||||||
self.params.observations_history = None
|
|
||||||
else:
|
|
||||||
self.params.observations_history = config["observations_history"]
|
self.params.observations_history = config["observations_history"]
|
||||||
|
|
||||||
self.params.clip_obs = config["clip_obs"]
|
self.params.clip_obs = config["clip_obs"]
|
||||||
self.params.action_scale = config["action_scale"]
|
self.params.action_scale = config["action_scale"]
|
||||||
self.params.hip_scale_reduction = config["hip_scale_reduction"]
|
self.params.hip_scale_reduction = config["hip_scale_reduction"]
|
||||||
|
|
|
@ -39,7 +39,7 @@ class RL_Sim(RL):
|
||||||
self.params.observations[i] = "ang_vel_world"
|
self.params.observations[i] = "ang_vel_world"
|
||||||
|
|
||||||
# history
|
# history
|
||||||
if self.params.observations_history is None:
|
if len(self.params.observations_history) != 0:
|
||||||
self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history))
|
self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history))
|
||||||
|
|
||||||
# Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
|
# Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
|
||||||
|
@ -205,7 +205,7 @@ class RL_Sim(RL):
|
||||||
def Forward(self):
|
def Forward(self):
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
clamped_obs = self.ComputeObservation()
|
clamped_obs = self.ComputeObservation()
|
||||||
if self.params.observations_history is None:
|
if len(self.params.observations_history) != 0:
|
||||||
self.history_obs_buf.insert(clamped_obs)
|
self.history_obs_buf.insert(clamped_obs)
|
||||||
history_obs = self.history_obs_buf.get_obs_vec(self.params.observations_history)
|
history_obs = self.history_obs_buf.get_obs_vec(self.params.observations_history)
|
||||||
actions = self.model.forward(history_obs)
|
actions = self.model.forward(history_obs)
|
||||||
|
|
|
@ -20,7 +20,7 @@ RL_Sim::RL_Sim()
|
||||||
}
|
}
|
||||||
|
|
||||||
// history
|
// history
|
||||||
if (!this->params.observations_history.empty())
|
if (this->params.observations_history.size() != 0)
|
||||||
{
|
{
|
||||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
|
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
|
||||||
}
|
}
|
||||||
|
@ -259,7 +259,7 @@ torch::Tensor RL_Sim::Forward()
|
||||||
torch::Tensor clamped_obs = this->ComputeObservation();
|
torch::Tensor clamped_obs = this->ComputeObservation();
|
||||||
|
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
if (!this->params.observations_history.empty())
|
if (this->params.observations_history.size() != 0)
|
||||||
{
|
{
|
||||||
this->history_obs_buf.insert(clamped_obs);
|
this->history_obs_buf.insert(clamped_obs);
|
||||||
this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history);
|
this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history);
|
||||||
|
|
Loading…
Reference in New Issue