From 05c38eedbae602e120c69c7c83209320e411b4bf Mon Sep 17 00:00:00 2001 From: fan-ziqi Date: Mon, 3 Mar 2025 16:38:06 +0800 Subject: [PATCH] fix history buffer --- .../core/observation_buffer/observation_buffer.cpp | 10 +++++++++- src/rl_sar/models/a1/legged_gym/config.yaml | 2 +- src/rl_sar/models/a1/robot_lab/config.yaml | 2 +- src/rl_sar/models/b2/robot_lab/config.yaml | 2 +- src/rl_sar/models/b2w/robot_lab/config.yaml | 2 +- src/rl_sar/models/go2/himloco/config.yaml | 2 +- src/rl_sar/models/go2/robot_lab/config.yaml | 2 +- src/rl_sar/models/go2w/robot_lab/config.yaml | 2 +- src/rl_sar/models/gr1t1/legged_gym/config.yaml | 2 +- src/rl_sar/models/gr1t2/legged_gym/config.yaml | 2 +- src/rl_sar/models/l4w4/legged_gym/config.yaml | 2 +- src/rl_sar/scripts/observation_buffer.py | 6 +++--- 12 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/rl_sar/library/core/observation_buffer/observation_buffer.cpp b/src/rl_sar/library/core/observation_buffer/observation_buffer.cpp index 96c27d2..7557af1 100644 --- a/src/rl_sar/library/core/observation_buffer/observation_buffer.cpp +++ b/src/rl_sar/library/core/observation_buffer/observation_buffer.cpp @@ -38,10 +38,18 @@ void ObservationBuffer::insert(torch::Tensor new_obs) obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs, torch::indexing::None)}) = new_obs; } +/** + * @brief Gets history of observations indexed by obs_ids. + * + * @param obs_ids An array of integers with which to index the desired + * observations, where 0 is the latest observation and + * include_history_steps - 1 is the oldest observation. + * @return A torch::Tensor containing the concatenated observations. + */ torch::Tensor ObservationBuffer::get_obs_vec(std::vector obs_ids) { std::vector obs; - for (int i = obs_ids.size() - 1; i >= 0; --i) + for (int i = 0; i < obs_ids.size(); ++i) { int obs_id = obs_ids[i]; int slice_idx = include_history_steps - obs_id - 1; diff --git a/src/rl_sar/models/a1/legged_gym/config.yaml b/src/rl_sar/models/a1/legged_gym/config.yaml index 23f89eb..62fa06d 100644 --- a/src/rl_sar/models/a1/legged_gym/config.yaml +++ b/src/rl_sar/models/a1/legged_gym/config.yaml @@ -8,7 +8,7 @@ a1/legged_gym: decimation: 4 num_observations: 45 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [0, 1, 2, 3, 4, 5] + observations_history: [5, 4, 3, 2, 1, 0] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-100.0, -100.0, -100.0, -100.0, -100.0, -100.0, diff --git a/src/rl_sar/models/a1/robot_lab/config.yaml b/src/rl_sar/models/a1/robot_lab/config.yaml index f58acab..67b406b 100644 --- a/src/rl_sar/models/a1/robot_lab/config.yaml +++ b/src/rl_sar/models/a1/robot_lab/config.yaml @@ -8,7 +8,7 @@ a1/robot_lab: decimation: 4 num_observations: 45 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [] + observations_history: [] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-100.0, -100.0, -100.0, -100.0, -100.0, -100.0, diff --git a/src/rl_sar/models/b2/robot_lab/config.yaml b/src/rl_sar/models/b2/robot_lab/config.yaml index b630f93..37b8b39 100644 --- a/src/rl_sar/models/b2/robot_lab/config.yaml +++ b/src/rl_sar/models/b2/robot_lab/config.yaml @@ -8,7 +8,7 @@ b2/robot_lab: decimation: 4 num_observations: 45 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [] + observations_history: [] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-100.0, -100.0, -100.0, -100.0, -100.0, -100.0, diff --git a/src/rl_sar/models/b2w/robot_lab/config.yaml b/src/rl_sar/models/b2w/robot_lab/config.yaml index cc0ef7f..318e5b5 100644 --- a/src/rl_sar/models/b2w/robot_lab/config.yaml +++ b/src/rl_sar/models/b2w/robot_lab/config.yaml @@ -5,7 +5,7 @@ b2w/robot_lab: decimation: 4 num_observations: 57 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [] + observations_history: [] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-100.0, -100.0, -100.0, -100.0, -100.0, -100.0, diff --git a/src/rl_sar/models/go2/himloco/config.yaml b/src/rl_sar/models/go2/himloco/config.yaml index 3b54a91..99e56d4 100644 --- a/src/rl_sar/models/go2/himloco/config.yaml +++ b/src/rl_sar/models/go2/himloco/config.yaml @@ -8,7 +8,7 @@ go2/himloco: decimation: 4 num_observations: 45 observations: ["commands", "ang_vel", "gravity_vec", "dof_pos", "dof_vel", "actions"] - observations_history: [5, 4, 3, 2, 1, 0] + observations_history: [0, 1, 2, 3, 4, 5] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-100, -100, -100, -100, -100, -100, diff --git a/src/rl_sar/models/go2/robot_lab/config.yaml b/src/rl_sar/models/go2/robot_lab/config.yaml index c1cec0d..a1c0dff 100644 --- a/src/rl_sar/models/go2/robot_lab/config.yaml +++ b/src/rl_sar/models/go2/robot_lab/config.yaml @@ -8,7 +8,7 @@ go2/robot_lab: decimation: 4 num_observations: 45 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [] + observations_history: [] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-100.0, -100.0, -100.0, -100.0, -100.0, -100.0, diff --git a/src/rl_sar/models/go2w/robot_lab/config.yaml b/src/rl_sar/models/go2w/robot_lab/config.yaml index 14c32a2..379b71d 100644 --- a/src/rl_sar/models/go2w/robot_lab/config.yaml +++ b/src/rl_sar/models/go2w/robot_lab/config.yaml @@ -5,7 +5,7 @@ go2w/robot_lab: decimation: 4 num_observations: 57 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [] + observations_history: [] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-100.0, -100.0, -100.0, -100.0, -100.0, -100.0, diff --git a/src/rl_sar/models/gr1t1/legged_gym/config.yaml b/src/rl_sar/models/gr1t1/legged_gym/config.yaml index e374f25..7ef7d6f 100644 --- a/src/rl_sar/models/gr1t1/legged_gym/config.yaml +++ b/src/rl_sar/models/gr1t1/legged_gym/config.yaml @@ -8,7 +8,7 @@ gr1t1/legged_gym: decimation: 20 num_observations: 39 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [] + observations_history: [] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991, -1.1391, -1.0491, -2.0991, -0.4391, -1.3991] diff --git a/src/rl_sar/models/gr1t2/legged_gym/config.yaml b/src/rl_sar/models/gr1t2/legged_gym/config.yaml index 195b9dd..3add096 100644 --- a/src/rl_sar/models/gr1t2/legged_gym/config.yaml +++ b/src/rl_sar/models/gr1t2/legged_gym/config.yaml @@ -8,7 +8,7 @@ gr1t2/legged_gym: decimation: 20 num_observations: 39 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [] + observations_history: [] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991, -1.1391, -1.0491, -2.0991, -0.4391, -1.3991] diff --git a/src/rl_sar/models/l4w4/legged_gym/config.yaml b/src/rl_sar/models/l4w4/legged_gym/config.yaml index f4ffe05..1dc777a 100644 --- a/src/rl_sar/models/l4w4/legged_gym/config.yaml +++ b/src/rl_sar/models/l4w4/legged_gym/config.yaml @@ -10,7 +10,7 @@ l4w4/legged_gym: decimation: 4 num_observations: 57 observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] - observations_history: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + observations_history: [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] # 0 is the latest observation clip_obs: 100.0 clip_actions_lower: [-100.0, -100.0, -100.0, -100.0, -100.0, -100.0, -100.0, -100.0, diff --git a/src/rl_sar/scripts/observation_buffer.py b/src/rl_sar/scripts/observation_buffer.py index 9293dd5..2b010a8 100644 --- a/src/rl_sar/scripts/observation_buffer.py +++ b/src/rl_sar/scripts/observation_buffer.py @@ -29,12 +29,12 @@ class ObservationBuffer: Arguments: obs_ids: An array of integers with which to index the desired - observations, where 0 is the latest observation and - include_history_steps - 1 is the oldest observation. + observations, where 0 is the latest observation and + include_history_steps - 1 is the oldest observation. """ obs = [] - for obs_id in reversed(obs_ids): + for obs_id in obs_ids: slice_idx = self.include_history_steps - obs_id - 1 obs.append(self.obs_buf[:, slice_idx * self.num_obs : (slice_idx + 1) * self.num_obs]) return torch.cat(obs, dim=-1)