mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add obs list in config.yaml and move ComputeObservation to rl_sdk
This commit is contained in:
parent
4540198171
commit
1d94debeef
14
README.md
14
README.md
|
@ -57,14 +57,7 @@ sudo ldconfig
|
||||||
|
|
||||||
## Compilation
|
## Compilation
|
||||||
|
|
||||||
Customize the following two functions in your code to adapt to different models:
|
Compile in the root directory of the project
|
||||||
|
|
||||||
```cpp
|
|
||||||
torch::Tensor forward() override;
|
|
||||||
torch::Tensor compute_observation() override;
|
|
||||||
```
|
|
||||||
|
|
||||||
Then compile in the root directory
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd ..
|
cd ..
|
||||||
|
@ -142,8 +135,9 @@ In the following text, `<ROBOT>` represents the name of the robot
|
||||||
1. Create a model package named `<ROBOT>_description` in the `rl_sar/src/robots` directory. Place the robot's URDF file in the `rl_sar/src/robots/<ROBOT>_description/urdf` directory and name it `<ROBOT>.urdf`. Additionally, create a joint configuration file with the namespace `<ROBOT>_gazebo` in the `rl_sar/src/robots/<ROBOT>_description/config` directory.
|
1. Create a model package named `<ROBOT>_description` in the `rl_sar/src/robots` directory. Place the robot's URDF file in the `rl_sar/src/robots/<ROBOT>_description/urdf` directory and name it `<ROBOT>.urdf`. Additionally, create a joint configuration file with the namespace `<ROBOT>_gazebo` in the `rl_sar/src/robots/<ROBOT>_description/config` directory.
|
||||||
2. Place the trained RL model files in the `rl_sar/src/rl_sar/models/<ROBOT>` directory.
|
2. Place the trained RL model files in the `rl_sar/src/rl_sar/models/<ROBOT>` directory.
|
||||||
3. In the `rl_sar/src/rl_sar/models/<ROBOT>` directory, create a `config.yaml` file, and modify its parameters based on the `rl_sar/src/rl_sar/models/a1_isaacgym/config.yaml` file.
|
3. In the `rl_sar/src/rl_sar/models/<ROBOT>` directory, create a `config.yaml` file, and modify its parameters based on the `rl_sar/src/rl_sar/models/a1_isaacgym/config.yaml` file.
|
||||||
4. If you need to run simulations, modify the launch files as needed by referring to those in the `rl_sar/src/rl_sar/launch` directory.
|
4. Modify the `forward()` function in the code as needed to adapt to different models.
|
||||||
5. If you need to run on the physical robot, modify the file `rl_sar/src/rl_sar/src/rl_real_a1.cpp` as needed.
|
5. If you need to run simulations, modify the launch files as needed by referring to those in the `rl_sar/src/rl_sar/launch` directory.
|
||||||
|
6. If you need to run on the physical robot, modify the file `rl_sar/src/rl_sar/src/rl_real_a1.cpp` as needed.
|
||||||
|
|
||||||
## Reference
|
## Reference
|
||||||
|
|
||||||
|
|
14
README_CN.md
14
README_CN.md
|
@ -57,14 +57,7 @@ sudo ldconfig
|
||||||
|
|
||||||
## 编译
|
## 编译
|
||||||
|
|
||||||
自定义代码中的以下两个函数,以适配不同的模型:
|
在项目根目录编译
|
||||||
|
|
||||||
```cpp
|
|
||||||
torch::Tensor forward() override;
|
|
||||||
torch::Tensor compute_observation() override;
|
|
||||||
```
|
|
||||||
|
|
||||||
然后到根目录编译
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd ..
|
cd ..
|
||||||
|
@ -143,8 +136,9 @@ rosrun rl_sar rl_real_a1
|
||||||
1. 在`rl_sar/src/robots`路径下创建名为`<ROBOT>_description`的模型包,将模型的urdf放到`rl_sar/src/robots/<ROBOT>_description/urdf`路径下并命名为`<ROBOT>.urdf`,并在`rl_sar/src/robots/<ROBOT>_description/config`路径下创建命名空间为`<ROBOT>_gazebo`的关节配置文件
|
1. 在`rl_sar/src/robots`路径下创建名为`<ROBOT>_description`的模型包,将模型的urdf放到`rl_sar/src/robots/<ROBOT>_description/urdf`路径下并命名为`<ROBOT>.urdf`,并在`rl_sar/src/robots/<ROBOT>_description/config`路径下创建命名空间为`<ROBOT>_gazebo`的关节配置文件
|
||||||
2. 将训练好的RL模型文件放到`rl_sar/src/rl_sar/models/<ROBOT>`路径下
|
2. 将训练好的RL模型文件放到`rl_sar/src/rl_sar/models/<ROBOT>`路径下
|
||||||
3. 在`rl_sar/src/rl_sar/models/<ROBOT>`中新建config.yaml文件,参考`rl_sar/src/rl_sar/models/a1_isaacgym/config.yaml`文件修改其中参数
|
3. 在`rl_sar/src/rl_sar/models/<ROBOT>`中新建config.yaml文件,参考`rl_sar/src/rl_sar/models/a1_isaacgym/config.yaml`文件修改其中参数
|
||||||
4. 若需要运行仿真,则参考`rl_sar/src/rl_sar/launch`路径下的launch文件自行修改
|
4. 按需修改代码中的`forward()`函数,以适配不同的模型
|
||||||
5. 若需要运行实物,则参考`rl_sar/src/rl_sar/src/rl_real_a1.cpp`文件自行修改
|
5. 若需要运行仿真,则参考`rl_sar/src/rl_sar/launch`路径下的launch文件自行修改
|
||||||
|
6. 若需要运行实物,则参考`rl_sar/src/rl_sar/src/rl_real_a1.cpp`文件自行修改
|
||||||
|
|
||||||
## 参考
|
## 参考
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,6 @@ public:
|
||||||
private:
|
private:
|
||||||
// rl functions
|
// rl functions
|
||||||
torch::Tensor Forward() override;
|
torch::Tensor Forward() override;
|
||||||
torch::Tensor ComputeObservation() override;
|
|
||||||
void GetState(RobotState<double> *state) override;
|
void GetState(RobotState<double> *state) override;
|
||||||
void SetCommand(const RobotCommand<double> *command) override;
|
void SetCommand(const RobotCommand<double> *command) override;
|
||||||
void RunModel();
|
void RunModel();
|
||||||
|
|
|
@ -24,7 +24,6 @@ public:
|
||||||
private:
|
private:
|
||||||
// rl functions
|
// rl functions
|
||||||
torch::Tensor Forward() override;
|
torch::Tensor Forward() override;
|
||||||
torch::Tensor ComputeObservation() override;
|
|
||||||
void GetState(RobotState<double> *state) override;
|
void GetState(RobotState<double> *state) override;
|
||||||
void SetCommand(const RobotCommand<double> *command) override;
|
void SetCommand(const RobotCommand<double> *command) override;
|
||||||
void RunModel();
|
void RunModel();
|
||||||
|
|
|
@ -1,21 +1,5 @@
|
||||||
#include "rl_sdk.hpp"
|
#include "rl_sdk.hpp"
|
||||||
|
|
||||||
/* You may need to override this ComputeObservation() function
|
|
||||||
torch::Tensor RL_XXX::ComputeObservation()
|
|
||||||
{
|
|
||||||
torch::Tensor obs = torch::cat({
|
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale,
|
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework),
|
|
||||||
this->obs.commands * this->params.commands_scale,
|
|
||||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
|
||||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
|
||||||
this->obs.actions
|
|
||||||
},1);
|
|
||||||
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
|
||||||
return clamped_obs;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
/* You may need to override this Forward() function
|
/* You may need to override this Forward() function
|
||||||
torch::Tensor RL_XXX::Forward()
|
torch::Tensor RL_XXX::Forward()
|
||||||
{
|
{
|
||||||
|
@ -27,6 +11,48 @@ torch::Tensor RL_XXX::Forward()
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
/**
 * @brief Assemble the policy observation from the terms listed in the config.
 *
 * Walks `params.observations` in order and, for every recognised name,
 * appends the corresponding (scaled) term taken from `obs`. The collected
 * tensors are concatenated along dim 1 and clamped to
 * [-clip_obs, clip_obs].
 *
 * Recognised names: "lin_vel", "ang_vel", "gravity_vec", "commands",
 * "dof_pos", "dof_vel", "actions".
 *
 * @return The concatenated and clamped observation tensor.
 * @throws c10::Error (raised by TORCH_CHECK) when the config lists a name
 *         that is not one of the recognised terms. Silently skipping such an
 *         entry would hand the policy a shorter or misaligned input vector.
 */
torch::Tensor RL::ComputeObservation()
{
    std::vector<torch::Tensor> obs_list;
    obs_list.reserve(this->params.observations.size());

    for(const std::string& observation : this->params.observations)
    {
        if(observation == "lin_vel")
        {
            obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale);
        }
        else if(observation == "ang_vel")
        {
            // obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessary?
            obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale);
        }
        else if(observation == "gravity_vec")
        {
            obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework));
        }
        else if(observation == "commands")
        {
            obs_list.push_back(this->obs.commands * this->params.commands_scale);
        }
        else if(observation == "dof_pos")
        {
            obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale);
        }
        else if(observation == "dof_vel")
        {
            obs_list.push_back(this->obs.dof_vel * this->params.dof_vel_scale);
        }
        else if(observation == "actions")
        {
            obs_list.push_back(this->obs.actions);
        }
        else
        {
            // Fail fast on a typo or unsupported term instead of dropping it.
            TORCH_CHECK(false, "Unknown observation name in config: '", observation, "'");
        }
    }

    torch::Tensor obs = torch::cat(obs_list, 1);
    torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
    return clamped_obs;
}
|
||||||
|
|
||||||
void RL::InitObservations()
|
void RL::InitObservations()
|
||||||
{
|
{
|
||||||
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
||||||
|
@ -369,6 +395,7 @@ void RL::ReadYaml(std::string robot_name)
|
||||||
this->params.dt = config["dt"].as<double>();
|
this->params.dt = config["dt"].as<double>();
|
||||||
this->params.decimation = config["decimation"].as<int>();
|
this->params.decimation = config["decimation"].as<int>();
|
||||||
this->params.num_observations = config["num_observations"].as<int>();
|
this->params.num_observations = config["num_observations"].as<int>();
|
||||||
|
this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
|
||||||
this->params.clip_obs = config["clip_obs"].as<double>();
|
this->params.clip_obs = config["clip_obs"].as<double>();
|
||||||
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1});
|
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1});
|
this->params.clip_actions_lower = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_lower"], this->params.framework, rows, cols)).view({1, -1});
|
||||||
|
|
|
@ -74,6 +74,7 @@ struct ModelParams
|
||||||
double dt;
|
double dt;
|
||||||
int decimation;
|
int decimation;
|
||||||
int num_observations;
|
int num_observations;
|
||||||
|
std::vector<std::string> observations;
|
||||||
double damping;
|
double damping;
|
||||||
double stiffness;
|
double stiffness;
|
||||||
double action_scale;
|
double action_scale;
|
||||||
|
@ -128,7 +129,7 @@ public:
|
||||||
|
|
||||||
// rl functions
|
// rl functions
|
||||||
virtual torch::Tensor Forward() = 0;
|
virtual torch::Tensor Forward() = 0;
|
||||||
virtual torch::Tensor ComputeObservation() = 0;
|
torch::Tensor ComputeObservation();
|
||||||
virtual void GetState(RobotState<double> *state) = 0;
|
virtual void GetState(RobotState<double> *state) = 0;
|
||||||
virtual void SetCommand(const RobotCommand<double> *command) = 0;
|
virtual void SetCommand(const RobotCommand<double> *command) = 0;
|
||||||
void StateController(const RobotState<double> *state, RobotCommand<double> *command);
|
void StateController(const RobotState<double> *state, RobotCommand<double> *command);
|
||||||
|
|
|
@ -7,6 +7,7 @@ a1_isaacgym:
|
||||||
dt: 0.005
|
dt: 0.005
|
||||||
decimation: 4
|
decimation: 4
|
||||||
num_observations: 45
|
num_observations: 45
|
||||||
|
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||||
clip_obs: 100.0
|
clip_obs: 100.0
|
||||||
clip_actions_lower: [-100, -100, -100,
|
clip_actions_lower: [-100, -100, -100,
|
||||||
-100, -100, -100,
|
-100, -100, -100,
|
||||||
|
|
|
@ -7,6 +7,7 @@ a1_isaacsim:
|
||||||
dt: 0.005
|
dt: 0.005
|
||||||
decimation: 4
|
decimation: 4
|
||||||
num_observations: 45
|
num_observations: 45
|
||||||
|
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||||
clip_obs: 100.0
|
clip_obs: 100.0
|
||||||
clip_actions_lower: [-100, -100, -100,
|
clip_actions_lower: [-100, -100, -100,
|
||||||
-100, -100, -100,
|
-100, -100, -100,
|
||||||
|
|
|
@ -7,6 +7,7 @@ gr1t1_isaacgym:
|
||||||
dt: 0.001
|
dt: 0.001
|
||||||
decimation: 20
|
decimation: 20
|
||||||
num_observations: 39
|
num_observations: 39
|
||||||
|
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||||
clip_obs: 100.0
|
clip_obs: 100.0
|
||||||
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
||||||
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
||||||
|
|
|
@ -7,6 +7,7 @@ gr1t1_isaacsim:
|
||||||
dt: 0.001
|
dt: 0.001
|
||||||
decimation: 20
|
decimation: 20
|
||||||
num_observations: 39
|
num_observations: 39
|
||||||
|
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||||
clip_obs: 100.0
|
clip_obs: 100.0
|
||||||
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
||||||
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
||||||
|
|
|
@ -7,6 +7,7 @@ gr1t2_isaacgym:
|
||||||
dt: 0.001
|
dt: 0.001
|
||||||
decimation: 20
|
decimation: 20
|
||||||
num_observations: 39
|
num_observations: 39
|
||||||
|
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||||
clip_obs: 100.0
|
clip_obs: 100.0
|
||||||
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
||||||
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
||||||
|
|
|
@ -7,6 +7,7 @@ gr1t2_isaacsim:
|
||||||
dt: 0.001
|
dt: 0.001
|
||||||
decimation: 20
|
decimation: 20
|
||||||
num_observations: 39
|
num_observations: 39
|
||||||
|
observations: ["ang_vel", "gravity_vec", "commands", "dof_pos", "dof_vel", "actions"]
|
||||||
clip_obs: 100.0
|
clip_obs: 100.0
|
||||||
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
||||||
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
-1.1391, -1.0491, -2.0991, -0.4391, -1.3991]
|
||||||
|
|
|
@ -68,6 +68,7 @@ class ModelParams:
|
||||||
self.dt = None
|
self.dt = None
|
||||||
self.decimation = None
|
self.decimation = None
|
||||||
self.num_observations = None
|
self.num_observations = None
|
||||||
|
self.observations = None
|
||||||
self.damping = None
|
self.damping = None
|
||||||
self.stiffness = None
|
self.stiffness = None
|
||||||
self.action_scale = None
|
self.action_scale = None
|
||||||
|
@ -134,6 +135,28 @@ class RL:
|
||||||
self.output_torques = torch.zeros(1, 32)
|
self.output_torques = torch.zeros(1, 32)
|
||||||
self.output_dof_pos = torch.zeros(1, 32)
|
self.output_dof_pos = torch.zeros(1, 32)
|
||||||
|
|
||||||
|
def ComputeObservation(self):
    """Assemble the policy observation from the terms listed in the config.

    Walks ``self.params.observations`` in order and, for every recognised
    name, appends the corresponding (scaled) term taken from ``self.obs``.
    The collected tensors are concatenated along the last dimension and
    clamped to ``[-clip_obs, clip_obs]``.

    Recognised names: "lin_vel", "ang_vel", "gravity_vec", "commands",
    "dof_pos", "dof_vel", "actions".

    Returns:
        The concatenated and clamped observation tensor.

    Raises:
        ValueError: if the config lists a name that is not one of the
            recognised terms. Silently skipping such an entry would hand
            the policy a shorter or misaligned input vector.
    """
    obs_list = []
    for observation in self.params.observations:
        if observation == "lin_vel":
            obs_list.append(self.obs.lin_vel * self.params.lin_vel_scale)
        elif observation == "ang_vel":
            # obs_list.append(self.obs.ang_vel * self.params.ang_vel_scale)  # TODO is QuatRotateInverse necessary?
            obs_list.append(self.QuatRotateInverse(self.obs.base_quat, self.obs.ang_vel, self.params.framework) * self.params.ang_vel_scale)
        elif observation == "gravity_vec":
            obs_list.append(self.QuatRotateInverse(self.obs.base_quat, self.obs.gravity_vec, self.params.framework))
        elif observation == "commands":
            obs_list.append(self.obs.commands * self.params.commands_scale)
        elif observation == "dof_pos":
            obs_list.append((self.obs.dof_pos - self.params.default_dof_pos) * self.params.dof_pos_scale)
        elif observation == "dof_vel":
            obs_list.append(self.obs.dof_vel * self.params.dof_vel_scale)
        elif observation == "actions":
            obs_list.append(self.obs.actions)
        else:
            # Fail fast on a typo or unsupported term instead of dropping it.
            raise ValueError("Unknown observation name in config: {!r}".format(observation))
    obs = torch.cat(obs_list, dim=-1)
    clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs)
    return clamped_obs
|
||||||
|
|
||||||
def InitObservations(self):
|
def InitObservations(self):
|
||||||
self.obs.lin_vel = torch.zeros(1, 3, dtype=torch.float)
|
self.obs.lin_vel = torch.zeros(1, 3, dtype=torch.float)
|
||||||
self.obs.ang_vel = torch.zeros(1, 3, dtype=torch.float)
|
self.obs.ang_vel = torch.zeros(1, 3, dtype=torch.float)
|
||||||
|
@ -359,6 +382,7 @@ class RL:
|
||||||
self.params.dt = config["dt"]
|
self.params.dt = config["dt"]
|
||||||
self.params.decimation = config["decimation"]
|
self.params.decimation = config["decimation"]
|
||||||
self.params.num_observations = config["num_observations"]
|
self.params.num_observations = config["num_observations"]
|
||||||
|
self.params.observations = config["observations"]
|
||||||
self.params.clip_obs = config["clip_obs"]
|
self.params.clip_obs = config["clip_obs"]
|
||||||
self.params.action_scale = config["action_scale"]
|
self.params.action_scale = config["action_scale"]
|
||||||
self.params.hip_scale_reduction = config["hip_scale_reduction"]
|
self.params.hip_scale_reduction = config["hip_scale_reduction"]
|
||||||
|
|
|
@ -173,7 +173,7 @@ class RL_Sim(RL):
|
||||||
|
|
||||||
def RunModel(self):
|
def RunModel(self):
|
||||||
if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running:
|
if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running:
|
||||||
# self.obs.lin_vel = torch.tensor([[self.vel.linear.x, self.vel.linear.y, self.vel.linear.z]])
|
self.obs.lin_vel = torch.tensor([[self.vel.linear.x, self.vel.linear.y, self.vel.linear.z]])
|
||||||
self.obs.ang_vel = torch.tensor(self.robot_state.imu.gyroscope).unsqueeze(0)
|
self.obs.ang_vel = torch.tensor(self.robot_state.imu.gyroscope).unsqueeze(0)
|
||||||
# self.obs.commands = torch.tensor([[self.cmd_vel.linear.x, self.cmd_vel.linear.y, self.cmd_vel.angular.z]])
|
# self.obs.commands = torch.tensor([[self.cmd_vel.linear.x, self.cmd_vel.linear.y, self.cmd_vel.angular.z]])
|
||||||
self.obs.commands = torch.tensor([[self.control.x, self.control.y, self.control.yaw]])
|
self.obs.commands = torch.tensor([[self.control.x, self.control.y, self.control.yaw]])
|
||||||
|
@ -199,20 +199,6 @@ class RL_Sim(RL):
|
||||||
tau_est = torch.tensor(self.mapped_joint_efforts).unsqueeze(0)
|
tau_est = torch.tensor(self.mapped_joint_efforts).unsqueeze(0)
|
||||||
self.CSVLogger(self.output_torques, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel)
|
self.CSVLogger(self.output_torques, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel)
|
||||||
|
|
||||||
def ComputeObservation(self):
|
|
||||||
obs = torch.cat([
|
|
||||||
# self.obs.lin_vel * self.params.lin_vel_scale,
|
|
||||||
# self.obs.ang_vel * self.params.ang_vel_scale, # TODO is QuatRotateInverse necessary?
|
|
||||||
self.QuatRotateInverse(self.obs.base_quat, self.obs.ang_vel, self.params.framework) * self.params.ang_vel_scale,
|
|
||||||
self.QuatRotateInverse(self.obs.base_quat, self.obs.gravity_vec, self.params.framework),
|
|
||||||
self.obs.commands * self.params.commands_scale,
|
|
||||||
(self.obs.dof_pos - self.params.default_dof_pos) * self.params.dof_pos_scale,
|
|
||||||
self.obs.dof_vel * self.params.dof_vel_scale,
|
|
||||||
self.obs.actions
|
|
||||||
], dim = -1)
|
|
||||||
clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs)
|
|
||||||
return clamped_obs
|
|
||||||
|
|
||||||
def Forward(self):
|
def Forward(self):
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
clamped_obs = self.ComputeObservation()
|
clamped_obs = self.ComputeObservation()
|
||||||
|
|
|
@ -170,21 +170,6 @@ void RL_Real::RunModel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor RL_Real::ComputeObservation()
|
|
||||||
{
|
|
||||||
torch::Tensor obs = torch::cat({
|
|
||||||
// this->QuatRotateInverse(this->obs.base_quat, this->obs.lin_vel) * this->params.lin_vel_scale,
|
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale,
|
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework),
|
|
||||||
this->obs.commands * this->params.commands_scale,
|
|
||||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
|
||||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
|
||||||
this->obs.actions
|
|
||||||
},1);
|
|
||||||
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
|
||||||
return clamped_obs;
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor RL_Real::Forward()
|
torch::Tensor RL_Real::Forward()
|
||||||
{
|
{
|
||||||
torch::autograd::GradMode::set_enabled(false);
|
torch::autograd::GradMode::set_enabled(false);
|
||||||
|
|
|
@ -212,7 +212,7 @@ void RL_Sim::RunModel()
|
||||||
{
|
{
|
||||||
if(this->running_state == STATE_RL_RUNNING && simulation_running)
|
if(this->running_state == STATE_RL_RUNNING && simulation_running)
|
||||||
{
|
{
|
||||||
// this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
|
this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
|
||||||
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
|
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
|
||||||
// this->obs.commands = torch::tensor({{this->cmd_vel.linear.x, this->cmd_vel.linear.y, this->cmd_vel.angular.z}});
|
// this->obs.commands = torch::tensor({{this->cmd_vel.linear.x, this->cmd_vel.linear.y, this->cmd_vel.angular.z}});
|
||||||
this->obs.commands = torch::tensor({{this->control.x, this->control.y, this->control.yaw}});
|
this->obs.commands = torch::tensor({{this->control.x, this->control.y, this->control.yaw}});
|
||||||
|
@ -243,22 +243,6 @@ void RL_Sim::RunModel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor RL_Sim::ComputeObservation()
|
|
||||||
{
|
|
||||||
torch::Tensor obs = torch::cat({
|
|
||||||
// this->obs.lin_vel * this->params.lin_vel_scale,
|
|
||||||
// this->obs.ang_vel * this->params.ang_vel_scale, // TODO is QuatRotateInverse necessary?
|
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale,
|
|
||||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework),
|
|
||||||
this->obs.commands * this->params.commands_scale,
|
|
||||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
|
||||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
|
||||||
this->obs.actions
|
|
||||||
},1);
|
|
||||||
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
|
||||||
return clamped_obs;
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor RL_Sim::Forward()
|
torch::Tensor RL_Sim::Forward()
|
||||||
{
|
{
|
||||||
torch::autograd::GradMode::set_enabled(false);
|
torch::autograd::GradMode::set_enabled(false);
|
||||||
|
|
Loading…
Reference in New Issue