mirror of https://github.com/fan-ziqi/rl_sar.git
Merge branch 'main' into devel
Commit 760dcfd55b

README.md — 15 changed lines

@@ -118,7 +118,20 @@ rosrun rl_sar rl_real_a1
Press the **R2** button on the controller to switch the robot to the default standing position, press **R1** to switch to RL control mode, and press **L2** in any state to switch back to the initial lying position. The left stick controls the x-axis (up/down) and yaw (left/right), and the right stick controls the y-axis (left/right).

-OR Press **0** on the keyboard to switch the robot to the default standing position, press **P** to switch to RL control mode, and press **1** in any state to switch to the initial lying position. WS controls x-axis, AD controls yaw, and JL controls y-axis.
+Or press **0** on the keyboard to switch the robot to the default standing position, press **P** to switch to RL control mode, and press **1** in any state to switch to the initial lying position. WS controls x-axis, AD controls yaw, and JL controls y-axis.
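For context, the keyboard bindings in the simulation are handled through a `pynput` key listener (registered in `rl_sim.py` as `keyboard.Listener(on_press=self.KeyboardInterface)`). Below is a minimal sketch of such a handler, shown only to illustrate the pattern; it is not the repository's actual `KeyboardInterface`:

```python
from pynput import keyboard

# Illustrative handler only (not the repository's implementation): map the keys
# documented above to mode switches and velocity-command keys.
def on_press(key):
    try:
        c = key.char
    except AttributeError:
        return  # ignore special keys
    if c == "0":
        print("switch to the default standing position")
    elif c == "p":
        print("switch to RL control mode")
    elif c == "1":
        print("switch to the initial lying position")
    elif c in "wsadjl":
        print(f"velocity command key: {c}")

listener = keyboard.Listener(on_press=on_press)
listener.start()
listener.join()  # block until the listener stops
```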
### Train the actuator network
1. Uncomment `#define CSV_LOGGER` at the top of `rl_real.cpp`. You can also modify the corresponding part of the simulation program to collect simulation data for testing the training process.
2. Run the control program; it will log all data while it runs.
3. Stop the control program and start training the actuator network. Note that the prefix `rl_sar/src/rl_sar/models/` is omitted from the paths below. (The expected CSV layout is sketched after this list.)
```bash
rosrun rl_sar actuator_net.py --mode train --data a1/motor.csv --output a1/motor.pt
```
4. Verify the trained actuator network.
```bash
rosrun rl_sar actuator_net.py --mode play --data a1/motor.csv --output a1/motor.pt
```
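The training script expects the logged CSV to contain, for each of the 12 motors, the `tau_cal_`, `tau_est_`, `joint_pos_`, `joint_pos_target_`, and `joint_vel_` columns written by `CSVInit`. A minimal sketch for sanity-checking a log before training, assuming pandas is available and using the `a1/motor.csv` path from the example above (this check is not part of the repository):

```python
import pandas as pd

# Hypothetical pre-training check (not part of rl_sar): verify that the logged CSV
# contains the column families that actuator_net.py reads.
data = pd.read_csv("a1/motor.csv")
prefixes = ["tau_cal_", "tau_est_", "joint_pos_", "joint_pos_target_", "joint_vel_"]
num_motors = sum(col.startswith("tau_est_") for col in data.columns)
missing = [f"{p}{i}" for p in prefixes for i in range(num_motors) if f"{p}{i}" not in data.columns]
print(f"{len(data)} samples, {num_motors} motors, missing columns: {missing or 'none'}")
```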
## Add Your Robot
README_CN.md — 15 changed lines

@@ -119,7 +119,20 @@ rosrun rl_sar rl_real_a1
Press the **R2** button on the controller to switch the robot to the default standing position, press **R1** to switch to RL control mode, and press **L2** in any state to switch back to the initial lying position. The left stick controls the x-axis (up/down) and yaw (left/right), and the right stick controls the y-axis (left/right).

-OR press **0** on the keyboard to switch the robot to the default standing position, press **P** to switch to RL control mode, and press **1** in any state to switch to the initial lying position. WS controls the x-axis, AD controls yaw, and JL controls the y-axis.
+Or press **0** on the keyboard to switch the robot to the default standing position, press **P** to switch to RL control mode, and press **1** in any state to switch to the initial lying position. WS controls the x-axis, AD controls yaw, and JL controls the y-axis.

### Train the actuator network
1. Uncomment `#define CSV_LOGGER` at the top of `rl_real.cpp`. You can also modify the corresponding part of the simulation program to collect simulation data for testing the training process.
2. Run the control program; it will log all data while it runs.
3. Stop the control program and start training the actuator network. Note that the prefix `rl_sar/src/rl_sar/models/` is omitted from the paths below.
```bash
rosrun rl_sar actuator_net.py --mode train --data a1/motor.csv --output a1/motor.pt
```
4. Verify the trained actuator network.
```bash
rosrun rl_sar actuator_net.py --mode play --data a1/motor.csv --output a1/motor.pt
```
## Add Your Robot
CMakeLists.txt

@@ -83,6 +83,6 @@ target_link_libraries(rl_real_a1
catkin_install_python(PROGRAMS
  scripts/rl_sim.py
  scripts/rl_sdk.py
  scripts/actuator_net.py
  DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)
rl_sdk.cpp

@@ -388,7 +388,7 @@ void RL::CSVInit(std::string robot_name)
{
    csv_filename = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/motor";

-    // // Uncomment these lines if need timestamp for file name
+    // Uncomment these lines if need timestamp for file name
    // auto now = std::chrono::system_clock::now();
    // std::time_t now_c = std::chrono::system_clock::to_time_t(now);
    // std::stringstream ss;
@@ -399,7 +399,7 @@ void RL::CSVInit(std::string robot_name)
    csv_filename += ".csv";
    std::ofstream file(csv_filename.c_str());

-    for(int i = 0; i < 12; ++i) {file << "torque_" << i << ",";}
+    for(int i = 0; i < 12; ++i) {file << "tau_cal_" << i << ",";}
    for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";}
    for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";}
    for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";}
rl_sdk.py

@@ -1,10 +1,12 @@
import torch
import yaml
import os
import csv
from pynput import keyboard
from enum import Enum, auto

-CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../config.yaml")
+BASE_PATH = os.path.join(os.path.dirname(__file__), "../../")
+CONFIG_PATH = os.path.join(BASE_PATH, "config.yaml")

class LOGGER:
    INFO = "\033[0;37m[INFO]\033[0m "
@@ -372,3 +374,37 @@ class RL:
        self.params.default_dof_pos = torch.tensor(self.ReadVectorFromYaml(config["default_dof_pos"], self.params.framework, rows, cols)).view(1, -1)
        self.params.joint_controller_names = self.ReadVectorFromYaml(config["joint_controller_names"], self.params.framework, rows, cols)

    def CSVInit(self, robot_name):
        self.csv_filename = os.path.join(BASE_PATH, "models", robot_name, 'motor')

        # Uncomment these lines if need timestamp for file name
        # now = datetime.now()
        # timestamp = now.strftime("%Y%m%d%H%M%S")
        # self.csv_filename += f"_{timestamp}"

        self.csv_filename += ".csv"

        with open(self.csv_filename, 'w', newline='') as file:
            writer = csv.writer(file)

            header = []
            header += [f"tau_cal_{i}" for i in range(12)]
            header += [f"tau_est_{i}" for i in range(12)]
            header += [f"joint_pos_{i}" for i in range(12)]
            header += [f"joint_pos_target_{i}" for i in range(12)]
            header += [f"joint_vel_{i}" for i in range(12)]

            writer.writerow(header)

    def CSVLogger(self, torque, tau_est, joint_pos, joint_pos_target, joint_vel):
        with open(self.csv_filename, 'a', newline='') as file:
            writer = csv.writer(file)

            row = []
            row += [torque[0][i].item() for i in range(12)]
            row += [tau_est[0][i].item() for i in range(12)]
            row += [joint_pos[0][i].item() for i in range(12)]
            row += [joint_pos_target[0][i].item() for i in range(12)]
            row += [joint_vel[0][i].item() for i in range(12)]

            writer.writerow(row)
scripts/actuator_net.py (new file)

@@ -0,0 +1,258 @@
import os
import argparse
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import pandas as pd

BASE_PATH = os.path.join(os.path.dirname(__file__), "../")

class Config:
    def __init__(self):
        self.lr = 8e-4
        self.eps = 1e-8
        self.weight_decay = 0.0
        self.epochs = 200
        self.batch_size = 128
        self.device = "cuda:0"
        self.in_dim = 6
        self.units = 32
        self.layers = 2
        self.out_dim = 1
        self.act = "softsign"
        self.dt = 0.02

class ActuatorDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data["joint_states"])

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}

class Act(nn.Module):
    def __init__(self, act, slope=0.05):
        super(Act, self).__init__()
        self.act = act
        self.slope = slope
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, input):
        if self.act == "relu":
            return F.relu(input)
        elif self.act == "leaky_relu":
            return F.leaky_relu(input)
        elif self.act == "sp":
            return F.softplus(input, beta=1.0)
        elif self.act == "leaky_sp":
            return F.softplus(input, beta=1.0) - self.slope * F.relu(-input)
        elif self.act == "elu":
            return F.elu(input, alpha=1.0)
        elif self.act == "leaky_elu":
            return F.elu(input, alpha=1.0) - self.slope * F.relu(-input)
        elif self.act == "ssp":
            return F.softplus(input, beta=1.0) - self.shift
        elif self.act == "leaky_ssp":
            return (
                F.softplus(input, beta=1.0) - self.slope * F.relu(-input) - self.shift
            )
        elif self.act == "tanh":
            return torch.tanh(input)
        elif self.act == "leaky_tanh":
            return torch.tanh(input) + self.slope * input
        elif self.act == "swish":
            return torch.sigmoid(input) * input
        elif self.act == "softsign":
            return F.softsign(input)
        else:
            raise RuntimeError(f"Undefined activation called {self.act}")

def build_mlp(config):
    mods = [nn.Linear(config.in_dim, config.units), Act(config.act)]
    for i in range(config.layers - 1):
        mods += [nn.Linear(config.units, config.units), Act(config.act)]
    mods += [nn.Linear(config.units, config.out_dim)]
    return nn.Sequential(*mods)
def load_data(data_path):
    data = pd.read_csv(data_path)
    if len(data) < 1:
        return None, 0

    num_motors = sum(1 for col in data.columns if col.startswith("tau_est_"))
    columns = ["tau_est_", "tau_cal_", "joint_pos_", "joint_pos_target_", "joint_vel_"]

    data_dict = {col: [] for col in columns}
    for col in columns:
        for i in range(num_motors):
            data_dict[col].append(data[f"{col}{i}"].values)

    for key in data_dict.keys():
        data_dict[key] = np.array(data_dict[key]).T

    return data_dict, num_motors

def process_data(data_dict, num_motors, step):
    joint_position_errors = data_dict["joint_pos_target_"] - data_dict["joint_pos_"]
    joint_velocities = data_dict["joint_vel_"]
    tau_ests = data_dict["tau_est_"]

    joint_position_errors = torch.tensor(joint_position_errors, dtype=torch.float)
    joint_velocities = torch.tensor(joint_velocities, dtype=torch.float)
    tau_ests = torch.tensor(tau_ests, dtype=torch.float)

    xs, ys = [], []
    for i in range(num_motors):
        xs_joint = [
            joint_position_errors[step:, i:i+1],
            joint_position_errors[step-1:-1, i:i+1],
            joint_position_errors[step-2:-2, i:i+1],
            joint_velocities[step:, i:i+1],
            joint_velocities[step-1:-1, i:i+1],
            joint_velocities[step-2:-2, i:i+1],
        ]
        tau_ests_joint = tau_ests[step:, i:i+1]

        xs_joint = torch.cat(xs_joint, dim=1)
        xs.append(xs_joint)
        ys.append(tau_ests_joint)

    xs = torch.cat(xs, dim=0)
    ys = torch.cat(ys, dim=0)
    return xs, ys
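# Clarifying note (not part of the committed file): process_data builds, for each joint,
# a 6-dimensional input [pos_err(t), pos_err(t-1), pos_err(t-2), vel(t), vel(t-1), vel(t-2)]
# with the measured torque tau_est(t) as the regression target, which is why Config.in_dim
# is 6 and Config.out_dim is 1; samples from all joints are stacked along dim 0.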
def train_actuator_network(xs, ys, actuator_network_path, config):
    num_data = xs.shape[0]
    num_train = num_data // 5 * 4
    num_test = num_data - num_train

    dataset = ActuatorDataset({"joint_states": xs, "tau_ests": ys})
    train_set, val_set = torch.utils.data.random_split(dataset, [num_train, num_test])
    train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(val_set, batch_size=config.batch_size, shuffle=True)

    model = build_mlp(config)

    opt = Adam(model.parameters(), lr=config.lr, eps=config.eps, weight_decay=config.weight_decay)

    model = model.to(config.device)
    for epoch in range(config.epochs):
        epoch_loss = 0
        ct = 0
        for batch in train_loader:
            data = batch["joint_states"].to(config.device)
            y_pred = model(data)

            opt.zero_grad()

            y_label = batch["tau_ests"].to(config.device)

            tau_est_loss = ((y_pred - y_label) ** 2).mean()
            loss = tau_est_loss

            loss.backward()
            opt.step()
            epoch_loss += loss.detach().cpu().numpy()
            ct += 1
        epoch_loss /= ct

        test_loss = 0
        mae = 0
        ct = 0
        if epoch % 1 == 0:
            with torch.no_grad():
                for batch in test_loader:
                    data = batch["joint_states"].to(config.device)
                    y_pred = model(data)

                    y_label = batch["tau_ests"].to(config.device)

                    tau_est_loss = ((y_pred - y_label) ** 2).mean()
                    loss = tau_est_loss
                    test_mae = (y_pred - y_label).abs().mean()

                    test_loss += loss
                    mae += test_mae
                    ct += 1
                test_loss /= ct
                mae /= ct

            print(f"epoch: {epoch} | loss: {epoch_loss:.4f} | test loss: {test_loss:.4f} | mae: {mae:.4f}")

    model_scripted = torch.jit.script(model)  # Export to TorchScript
    model_scripted.save(actuator_network_path)  # Save
    return model
def train_actuator_network_and_plot_predictions(data_path, actuator_network_path, load_pretrained_model=False, config=None):
    print(f"Load data: {data_path}")

    data_dict, num_motors = load_data(data_path)
    if data_dict is None:
        print(f"Failed to load data from {data_path}")
        return
    step = 2
    xs, ys = process_data(data_dict, num_motors, step)

    if load_pretrained_model:
        print("Evaluating the existing actuator network...")
        model = torch.jit.load(actuator_network_path).to("cpu")
        print(f"Use trained actuator network model: {actuator_network_path}")
    else:
        print("Training a new actuator network...")
        model = train_actuator_network(xs, ys, actuator_network_path, config).to("cpu")
        print(f"Saving actuator network model to: {actuator_network_path}")

    tau_preds = model(xs).detach().reshape(num_motors, -1).T

    plot_length = 1000

    timesteps = np.array(range(len(data_dict["tau_est_"]))) * config.dt
    timesteps = timesteps[step:step + len(tau_preds)]

    tau_cals = data_dict["tau_cal_"][step:step + len(tau_preds)]
    tau_ests = data_dict["tau_est_"][step:step + len(tau_preds)]
    tau_preds = tau_preds[:plot_length]

    fig, axs = plt.subplots(6, 2, figsize=(14, 6))
    axs = np.array(axs).flatten()
    for i in range(num_motors):
        axs[i].plot(timesteps[:plot_length], tau_cals[:plot_length, i], label="Calculated torque")
        axs[i].plot(timesteps[:plot_length], tau_ests[:plot_length, i], label="Real torque")
        axs[i].plot(timesteps[:plot_length], tau_preds[:plot_length, i], label="Predicted torque", linestyle="--")
    fig.legend(["Calculated torque", "Real torque", "Predicted torque"], loc='upper right', bbox_to_anchor=(1, 1))
    plt.show()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, required=True, choices=["train", "play"], help="Choose whether to train or evaluate the actuator network")
    parser.add_argument("--data", type=str, required=True, help="Path of data files")
    parser.add_argument("--output", type=str, required=True, help="Path to save or load the actuator network model")

    args = parser.parse_args()

    data_path = os.path.join(BASE_PATH, "models", args.data)
    output_path = os.path.join(BASE_PATH, "models", args.output)

    config = Config()

    if args.mode == "train":
        load_pretrained_model = False
    elif args.mode == "play":
        load_pretrained_model = True

    train_actuator_network_and_plot_predictions(
        data_path=data_path,
        actuator_network_path=output_path,
        load_pretrained_model=load_pretrained_model,
        config=config,
    )

if __name__ == "__main__":
    main()
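For reference, the `--output` file written above is a TorchScript module that maps the six per-joint inputs assembled by `process_data` (position error and joint velocity at the current and two previous steps) to a single predicted torque. A minimal inference sketch, assuming the `a1/motor.pt` path from the README example; this is an illustration, not code from this commit:

```python
import torch

# Load the TorchScript actuator network exported by actuator_net.py (path assumed).
net = torch.jit.load("a1/motor.pt").to("cpu")

# One row per joint: [pos_err(t), pos_err(t-1), pos_err(t-2), vel(t), vel(t-1), vel(t-2)].
features = torch.zeros(12, 6)  # placeholder values; fill from the robot state in practice
with torch.no_grad():
    tau = net(features)  # shape (12, 1): predicted torque per joint
print(tau.squeeze(-1))
```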
scripts/rl_sim.py

@@ -13,9 +13,11 @@ from gazebo_msgs.srv import SetModelState, SetModelStateRequest
from std_srvs.srv import Empty

path = os.path.abspath(".")
-sys.path.insert(0, path + "/src/rl_sar/scripts")
-from rl_sdk import *
-from observation_buffer import *
+sys.path.insert(0, path + "/src/rl_sar")
+from library.rl_sdk.rl_sdk import *
+from library.observation_buffer.observation_buffer import *

+CSV_LOGGER = False

class RL_Sim(RL):
    def __init__(self):
@@ -87,6 +89,10 @@ class RL_Sim(RL):
        self.listener_keyboard = keyboard.Listener(on_press=self.KeyboardInterface)
        self.listener_keyboard.start()

+        # others
+        if CSV_LOGGER:
+            self.CSVInit(self.robot_name)

        print(LOGGER.INFO + "RL_Sim start")

    def __del__(self):
@@ -190,6 +196,10 @@ class RL_Sim(RL):
        self.output_torques = torch.clamp(origin_output_torques, -(self.params.torque_limits), self.params.torque_limits)
        self.output_dof_pos = self.ComputePosition(self.obs.actions)

+        if CSV_LOGGER:
+            tau_est = torch.tensor(self.mapped_joint_efforts).unsqueeze(0)
+            self.CSVLogger(self.output_torques, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel)

    def ComputeObservation(self):
        obs = torch.cat([
            # self.obs.lin_vel * self.params.lin_vel_scale,