diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index 286bcef..0225bbb 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -4,6 +4,7 @@ // #define PLOT // #define CSV_LOGGER +#define USE_HISTORY RL_Sim::RL_Sim() { @@ -33,7 +34,9 @@ RL_Sim::RL_Sim() this->InitObservations(); this->InitOutputs(); +#ifdef USE_HISTORY this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); +#endif joint_positions = std::vector(params.num_of_dofs, 0.0); joint_velocities = std::vector(params.num_of_dofs, 0.0); @@ -175,10 +178,13 @@ torch::Tensor RL_Sim::Forward() { torch::Tensor obs = this->ComputeObservation(); +#ifdef USE_HISTORY history_obs_buf.insert(obs); history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); - torch::Tensor action = this->model.forward({history_obs}).toTensor(); +#else + torch::Tensor action = this->model.forward({obs}).toTensor(); +#endif this->obs.actions = action; torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);