mirror of https://github.com/fan-ziqi/rl_sar.git
fix: add USE_HISTORY
This commit is contained in:
parent
13dbb895bf
commit
87d9b697e5
|
@ -4,6 +4,7 @@
|
|||
|
||||
// #define PLOT
|
||||
// #define CSV_LOGGER
|
||||
#define USE_HISTORY
|
||||
|
||||
RL_Sim::RL_Sim()
|
||||
{
|
||||
|
@ -33,7 +34,9 @@ RL_Sim::RL_Sim()
|
|||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
|
||||
#ifdef USE_HISTORY
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
#endif
|
||||
|
||||
joint_positions = std::vector<double>(params.num_of_dofs, 0.0);
|
||||
joint_velocities = std::vector<double>(params.num_of_dofs, 0.0);
|
||||
|
@ -175,10 +178,13 @@ torch::Tensor RL_Sim::Forward()
|
|||
{
|
||||
torch::Tensor obs = this->ComputeObservation();
|
||||
|
||||
#ifdef USE_HISTORY
|
||||
history_obs_buf.insert(obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
|
||||
torch::Tensor action = this->model.forward({history_obs}).toTensor();
|
||||
#else
|
||||
torch::Tensor action = this->model.forward({obs}).toTensor();
|
||||
#endif
|
||||
|
||||
this->obs.actions = action;
|
||||
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);
|
||||
|
|
Loading…
Reference in New Issue