feat: remove the `use_history` field, add `observations_history` to support history observations, and adjust the related logic to handle the case of empty history observations

This commit is contained in:
fan-ziqi 2024-10-12 23:45:18 +08:00
parent 11ec554365
commit 274db72d4b
14 changed files with 151 additions and 119 deletions

View File

@ -421,11 +421,18 @@ void RL::ReadYaml(std::string robot_name)
this->params.framework = config["framework"].as<std::string>(); this->params.framework = config["framework"].as<std::string>();
int rows = config["rows"].as<int>(); int rows = config["rows"].as<int>();
int cols = config["cols"].as<int>(); int cols = config["cols"].as<int>();
this->params.use_history = config["use_history"].as<bool>();
this->params.dt = config["dt"].as<double>(); this->params.dt = config["dt"].as<double>();
this->params.decimation = config["decimation"].as<int>(); this->params.decimation = config["decimation"].as<int>();
this->params.num_observations = config["num_observations"].as<int>(); this->params.num_observations = config["num_observations"].as<int>();
this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]); this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
if (config["observations_history"].IsNull())
{
this->params.observations_history = {};
}
else
{
this->params.observations_history = ReadVectorFromYaml<int>(config["observations_history"]);
}
this->params.clip_obs = config["clip_obs"].as<double>(); this->params.clip_obs = config["clip_obs"].as<double>();
if (config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull()) if (config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull())
{ {

View File

@ -72,11 +72,11 @@ struct ModelParams
{ {
std::string model_name; std::string model_name;
std::string framework; std::string framework;
bool use_history;
double dt; double dt;
int decimation; int decimation;
int num_observations; int num_observations;
std::vector<std::string> observations; std::vector<std::string> observations;
std::vector<int> observations_history;
double damping; double damping;
double stiffness; double stiffness;
double action_scale; double action_scale;

View File

@ -3,11 +3,11 @@ a1_isaacgym:
framework: "isaacgym" framework: "isaacgym"
rows: 4 rows: 4
cols: 3 cols: 3
use_history: True
dt: 0.005 dt: 0.005
decimation: 4 decimation: 4
num_observations: 45 num_observations: 45
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: [0, 1, 2, 3, 4, 5]
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-100, -100, -100, clip_actions_lower: [-100, -100, -100,
-100, -100, -100, -100, -100, -100,

View File

@ -3,11 +3,11 @@ a1_isaacsim:
framework: "isaacsim" framework: "isaacsim"
rows: 4 rows: 4
cols: 3 cols: 3
use_history: False
dt: 0.005 dt: 0.005
decimation: 4 decimation: 4
num_observations: 45 num_observations: 45
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-100, -100, -100, clip_actions_lower: [-100, -100, -100,
-100, -100, -100, -100, -100, -100,

View File

@ -3,11 +3,11 @@ go2_isaacgym:
framework: "isaacgym" framework: "isaacgym"
rows: 4 rows: 4
cols: 3 cols: 3
use_history: True
dt: 0.005 dt: 0.005
decimation: 4 decimation: 4
num_observations: 45 num_observations: 45
observations: ["commands", "ang_vel", "gravity_vec", "dof_pos", "dof_vel", "actions"] observations: ["commands", "ang_vel", "gravity_vec", "dof_pos", "dof_vel", "actions"]
observations_history: [5, 4, 3, 2, 1, 0]
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-100, -100, -100, clip_actions_lower: [-100, -100, -100,
-100, -100, -100, -100, -100, -100,

View File

@ -3,11 +3,11 @@ gr1t1_isaacgym:
framework: "isaacgym" framework: "isaacgym"
rows: 2 rows: 2
cols: 5 cols: 5
use_history: False
dt: 0.001 dt: 0.001
decimation: 20 decimation: 20
num_observations: 39 num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991, clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991] -1.1391, -1.0491, -2.0991, -0.4391, -1.3991]

View File

@ -3,11 +3,11 @@ gr1t1_isaacsim:
framework: "isaacsim" framework: "isaacsim"
rows: 2 rows: 2
cols: 5 cols: 5
use_history: False
dt: 0.001 dt: 0.001
decimation: 20 decimation: 20
num_observations: 39 num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991, clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991] -1.1391, -1.0491, -2.0991, -0.4391, -1.3991]

View File

@ -3,11 +3,11 @@ gr1t2_isaacgym:
framework: "isaacgym" framework: "isaacgym"
rows: 2 rows: 2
cols: 5 cols: 5
use_history: False
dt: 0.001 dt: 0.001
decimation: 20 decimation: 20
num_observations: 39 num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991, clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991] -1.1391, -1.0491, -2.0991, -0.4391, -1.3991]

View File

@ -3,11 +3,11 @@ gr1t2_isaacsim:
framework: "isaacsim" framework: "isaacsim"
rows: 2 rows: 2
cols: 5 cols: 5
use_history: False
dt: 0.001 dt: 0.001
decimation: 20 decimation: 20
num_observations: 39 num_observations: 39
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"] observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
observations_history: []
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991, clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991] -1.1391, -1.0491, -2.0991, -0.4391, -1.3991]

View File

@ -64,11 +64,11 @@ class ModelParams:
def __init__(self): def __init__(self):
self.model_name = None self.model_name = None
self.framework = None self.framework = None
self.use_history = None
self.dt = None self.dt = None
self.decimation = None self.decimation = None
self.num_observations = None self.num_observations = None
self.observations = None self.observations = None
self.observations_history = None
self.damping = None self.damping = None
self.stiffness = None self.stiffness = None
self.action_scale = None self.action_scale = None
@ -378,16 +378,20 @@ class RL:
self.params.framework = config["framework"] self.params.framework = config["framework"]
rows = config["rows"] rows = config["rows"]
cols = config["cols"] cols = config["cols"]
self.params.use_history = config["use_history"]
self.params.dt = config["dt"] self.params.dt = config["dt"]
self.params.decimation = config["decimation"] self.params.decimation = config["decimation"]
self.params.num_observations = config["num_observations"] self.params.num_observations = config["num_observations"]
self.params.observations = config["observations"] self.params.observations = config["observations"]
if config["observations_history"] is None:
self.params.observations_history = None
else:
self.params.observations_history = config["observations_history"]
self.params.clip_obs = config["clip_obs"] self.params.clip_obs = config["clip_obs"]
self.params.action_scale = config["action_scale"] self.params.action_scale = config["action_scale"]
self.params.hip_scale_reduction = config["hip_scale_reduction"] self.params.hip_scale_reduction = config["hip_scale_reduction"]
self.params.hip_scale_reduction_indices = config["hip_scale_reduction_indices"] self.params.hip_scale_reduction_indices = config["hip_scale_reduction_indices"]
if config["clip_actions_upper"] is None and config["clip_actions_upper"] is None: if config["clip_actions_lower"] is None and config["clip_actions_upper"] is None:
self.params.clip_actions_upper = None self.params.clip_actions_upper = None
self.params.clip_actions_lower = None self.params.clip_actions_lower = None
else: else:

View File

@ -36,8 +36,8 @@ class RL_Sim(RL):
self.ReadYaml(self.robot_name) self.ReadYaml(self.robot_name)
# history # history
if self.params.use_history: if self.params.observations_history is None:
self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, 6) self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, len(self.params.observations_history))
# Due to the fact that the robot_state_publisher sorts the joint names alphabetically, # Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
# the mapping table is established according to the order defined in the YAML file # the mapping table is established according to the order defined in the YAML file
@ -202,9 +202,9 @@ class RL_Sim(RL):
def Forward(self): def Forward(self):
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
clamped_obs = self.ComputeObservation() clamped_obs = self.ComputeObservation()
if self.params.use_history: if self.params.observations_history is None:
self.history_obs_buf.insert(clamped_obs) self.history_obs_buf.insert(clamped_obs)
history_obs = self.history_obs_buf.get_obs_vec(np.arange(6)) history_obs = self.history_obs_buf.get_obs_vec(self.params.observations_history)
actions = self.model.forward(history_obs) actions = self.model.forward(history_obs)
else: else:
actions = self.model.forward(clamped_obs) actions = self.model.forward(clamped_obs)

View File

@ -12,7 +12,10 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->ReadYaml(this->robot_name); this->ReadYaml(this->robot_name);
// history // history
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); if (!this->params.observations_history.empty())
{
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
}
this->unitree_udp.InitCmdData(this->unitree_low_command); this->unitree_udp.InitCmdData(this->unitree_low_command);
@ -176,10 +179,17 @@ torch::Tensor RL_Real::Forward()
torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions;
if (!this->params.observations_history.empty())
{
this->history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history);
actions = this->model.forward({this->history_obs}).toTensor();
torch::Tensor actions = this->model.forward({this->history_obs}).toTensor(); }
else
{
actions = this->model.forward({clamped_obs}).toTensor();
}
if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0) if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{ {

View File

@ -12,7 +12,10 @@ void RL_Real::RL_Real()
this->ReadYaml(this->robot_name); this->ReadYaml(this->robot_name);
// history // history
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); if (!this->params.observations_history.empty())
{
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
}
// init robot // init robot
this->InitRobotStateClient(); this->InitRobotStateClient();
@ -187,10 +190,17 @@ torch::Tensor RL_Real::Forward()
torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions;
if (!this->params.observations_history.empty())
{
this->history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
this->history_obs = this->history_obs_buf.get_obs_vec({5, 4, 3, 2, 1, 0}); this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history);
actions = this->model.forward({this->history_obs}).toTensor();
torch::Tensor actions = this->model.forward({this->history_obs}).toTensor(); }
else
{
actions = this->model.forward({clamped_obs}).toTensor();
}
if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0) if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{ {

View File

@ -12,9 +12,9 @@ RL_Sim::RL_Sim()
this->ReadYaml(this->robot_name); this->ReadYaml(this->robot_name);
// history // history
if (this->params.use_history) if (!this->params.observations_history.empty())
{ {
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
} }
// Due to the fact that the robot_state_publisher sorts the joint names alphabetically, // Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
@ -247,13 +247,14 @@ void RL_Sim::RunModel()
torch::Tensor RL_Sim::Forward() torch::Tensor RL_Sim::Forward()
{ {
torch::autograd::GradMode::set_enabled(false); torch::autograd::GradMode::set_enabled(false);
torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions; torch::Tensor actions;
if (this->params.use_history) if (!this->params.observations_history.empty())
{ {
this->history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
// TODO-devel-go2 这里要找一种方法适配不同的顺序不能直接改这里会导致a1的模型不可用 this->history_obs = this->history_obs_buf.get_obs_vec(this->params.observations_history);
this->history_obs = this->history_obs_buf.get_obs_vec({5, 4, 3, 2, 1, 0});
actions = this->model.forward({this->history_obs}).toTensor(); actions = this->model.forward({this->history_obs}).toTensor();
} }
else else