mirror of https://github.com/fan-ziqi/rl_sar.git
fix history buffer
This commit is contained in:
parent
b066b9092e
commit
05c38eedba
|
@ -38,10 +38,18 @@ void ObservationBuffer::insert(torch::Tensor new_obs)
|
|||
obs_buf.index({torch::indexing::Slice(torch::indexing::None), torch::indexing::Slice(-num_obs, torch::indexing::None)}) = new_obs;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gets history of observations indexed by obs_ids.
|
||||
*
|
||||
* @param obs_ids An array of integers with which to index the desired
|
||||
* observations, where 0 is the latest observation and
|
||||
* include_history_steps - 1 is the oldest observation.
|
||||
* @return A torch::Tensor containing the concatenated observations.
|
||||
*/
|
||||
torch::Tensor ObservationBuffer::get_obs_vec(std::vector<int> obs_ids)
|
||||
{
|
||||
std::vector<torch::Tensor> obs;
|
||||
for (int i = obs_ids.size() - 1; i >= 0; --i)
|
||||
for (int i = 0; i < obs_ids.size(); ++i)
|
||||
{
|
||||
int obs_id = obs_ids[i];
|
||||
int slice_idx = include_history_steps - obs_id - 1;
|
||||
|
|
|
@ -8,7 +8,7 @@ a1/legged_gym:
|
|||
decimation: 4
|
||||
num_observations: 45
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: [0, 1, 2, 3, 4, 5]
|
||||
observations_history: [5, 4, 3, 2, 1, 0] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100.0, -100.0, -100.0,
|
||||
-100.0, -100.0, -100.0,
|
||||
|
|
|
@ -8,7 +8,7 @@ a1/robot_lab:
|
|||
decimation: 4
|
||||
num_observations: 45
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: []
|
||||
observations_history: [] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100.0, -100.0, -100.0,
|
||||
-100.0, -100.0, -100.0,
|
||||
|
|
|
@ -8,7 +8,7 @@ b2/robot_lab:
|
|||
decimation: 4
|
||||
num_observations: 45
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: []
|
||||
observations_history: [] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100.0, -100.0, -100.0,
|
||||
-100.0, -100.0, -100.0,
|
||||
|
|
|
@ -5,7 +5,7 @@ b2w/robot_lab:
|
|||
decimation: 4
|
||||
num_observations: 57
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: []
|
||||
observations_history: [] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100.0, -100.0, -100.0,
|
||||
-100.0, -100.0, -100.0,
|
||||
|
|
|
@ -8,7 +8,7 @@ go2/himloco:
|
|||
decimation: 4
|
||||
num_observations: 45
|
||||
observations: ["commands", "ang_vel", "gravity_vec", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: [5, 4, 3, 2, 1, 0]
|
||||
observations_history: [0, 1, 2, 3, 4, 5] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100, -100, -100,
|
||||
-100, -100, -100,
|
||||
|
|
|
@ -8,7 +8,7 @@ go2/robot_lab:
|
|||
decimation: 4
|
||||
num_observations: 45
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: []
|
||||
observations_history: [] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100.0, -100.0, -100.0,
|
||||
-100.0, -100.0, -100.0,
|
||||
|
|
|
@ -5,7 +5,7 @@ go2w/robot_lab:
|
|||
decimation: 4
|
||||
num_observations: 57
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: []
|
||||
observations_history: [] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100.0, -100.0, -100.0,
|
||||
-100.0, -100.0, -100.0,
|
||||
|
|
|
@ -8,7 +8,7 @@ gr1t1/legged_gym:
|
|||
decimation: 20
|
||||
num_observations: 39
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: []
|
||||
observations_history: [] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
||||
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
||||
|
|
|
@ -8,7 +8,7 @@ gr1t2/legged_gym:
|
|||
decimation: 20
|
||||
num_observations: 39
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: []
|
||||
observations_history: [] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
||||
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
||||
|
|
|
@ -10,7 +10,7 @@ l4w4/legged_gym:
|
|||
decimation: 4
|
||||
num_observations: 57
|
||||
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||
observations_history: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
observations_history: [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] # 0 is the latest observation
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100.0, -100.0, -100.0, -100.0,
|
||||
-100.0, -100.0, -100.0, -100.0,
|
||||
|
|
|
@ -29,12 +29,12 @@ class ObservationBuffer:
|
|||
|
||||
Arguments:
|
||||
obs_ids: An array of integers with which to index the desired
|
||||
observations, where 0 is the latest observation and
|
||||
include_history_steps - 1 is the oldest observation.
|
||||
observations, where 0 is the latest observation and
|
||||
include_history_steps - 1 is the oldest observation.
|
||||
"""
|
||||
|
||||
obs = []
|
||||
for obs_id in reversed(obs_ids):
|
||||
for obs_id in obs_ids:
|
||||
slice_idx = self.include_history_steps - obs_id - 1
|
||||
obs.append(self.obs_buf[:, slice_idx * self.num_obs : (slice_idx + 1) * self.num_obs])
|
||||
return torch.cat(obs, dim=-1)
|
||||
|
|
Loading…
Reference in New Issue