fix history buffer

This commit is contained in:
fan-ziqi 2025-03-03 16:38:06 +08:00
parent b066b9092e
commit 05c38eedba
12 changed files with 22 additions and 14 deletions

View File

@ -38,10 +38,18 @@ void ObservationBuffer::insert(torch::Tensor new_obs)
obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs, torch::indexing::None)}) = new_obs;
}
/**
* @brief Gets history of observations indexed by obs_ids.
*
* @param obs_ids An array of integers with which to index the desired
* observations, where 0 is the latest observation and
* include_history_steps - 1 is the oldest observation.
* @return A torch::Tensor containing the concatenated observations.
*/
torch::Tensor ObservationBuffer::get_obs_vec(std::vector<int> obs_ids)
{
std::vector<torch::Tensor> obs;
for (int i = obs_ids.size() - 1; i >= 0; --i)
for (int i = 0; i < obs_ids.size(); ++i)
{
int obs_id = obs_ids[i];
int slice_idx = include_history_steps - obs_id - 1;

View File

@ -8,7 +8,7 @@ a1/legged_gym:
decimation: 4
num_observations: 45
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: [0, 1, 2, 3, 4, 5]
observations_history: [5, 4, 3, 2, 1, 0] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-100.0, -100.0, -100.0,
-100.0, -100.0, -100.0,

View File

@ -8,7 +8,7 @@ a1/robot_lab:
decimation: 4
num_observations: 45
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
observations_history: [] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-100.0, -100.0, -100.0,
-100.0, -100.0, -100.0,

View File

@ -8,7 +8,7 @@ b2/robot_lab:
decimation: 4
num_observations: 45
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
observations_history: [] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-100.0, -100.0, -100.0,
-100.0, -100.0, -100.0,

View File

@ -5,7 +5,7 @@ b2w/robot_lab:
decimation: 4
num_observations: 57
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
observations_history: [] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-100.0, -100.0, -100.0,
-100.0, -100.0, -100.0,

View File

@ -8,7 +8,7 @@ go2/himloco:
decimation: 4
num_observations: 45
observations: ["commands", "ang_vel", "gravity_vec", "dof_pos", "dof_vel", "actions"]
observations_history: [5, 4, 3, 2, 1, 0]
observations_history: [0, 1, 2, 3, 4, 5] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-100, -100, -100,
-100, -100, -100,

View File

@ -8,7 +8,7 @@ go2/robot_lab:
decimation: 4
num_observations: 45
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
observations_history: [] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-100.0, -100.0, -100.0,
-100.0, -100.0, -100.0,

View File

@ -5,7 +5,7 @@ go2w/robot_lab:
decimation: 4
num_observations: 57
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
observations_history: [] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-100.0, -100.0, -100.0,
-100.0, -100.0, -100.0,

View File

@ -8,7 +8,7 @@ gr1t1/legged_gym:
decimation: 20
num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
observations_history: [] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]

View File

@ -8,7 +8,7 @@ gr1t2/legged_gym:
decimation: 20
num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
observations_history: [] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]

View File

@ -10,7 +10,7 @@ l4w4/legged_gym:
decimation: 4
num_observations: 57
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
observations_history: [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] # 0 is the latest observation
clip_obs: 100.0
clip_actions_lower: [-100.0, -100.0, -100.0, -100.0,
-100.0, -100.0, -100.0, -100.0,

View File

@ -29,12 +29,12 @@ class ObservationBuffer:
Arguments:
obs_ids: An array of integers with which to index the desired
observations, where 0 is the latest observation and
include_history_steps - 1 is the oldest observation.
observations, where 0 is the latest observation and
include_history_steps - 1 is the oldest observation.
"""
obs = []
for obs_id in reversed(obs_ids):
for obs_id in obs_ids:
slice_idx = self.include_history_steps - obs_id - 1
obs.append(self.obs_buf[:, slice_idx * self.num_obs : (slice_idx + 1) * self.num_obs])
return torch.cat(obs, dim=-1)