walk-these-ways-go2/go2_gym_learn/eval_metrics/metrics.py

100 lines
2.6 KiB
Python
Raw Normal View History

2024-01-28 17:11:38 +08:00
def to_numpy(fn):
    """Decorate a tensor-returning function so that it returns a NumPy array.

    The wrapped function's result is detached from any autograd graph, moved
    to the CPU and converted with ``Tensor.numpy()``.

    Args:
        fn: Callable returning a ``torch.Tensor``.

    Returns:
        A wrapper that forwards all arguments to ``fn`` and carries ``fn``'s
        metadata (``__name__``, ``__doc__``, ...).
    """
    import functools  # function-local so it is not captured by METRICS_FNS

    @functools.wraps(fn)
    def thunk(*args, **kwargs):
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        return fn(*args, **kwargs).detach().cpu().numpy()

    return thunk
def lin_vel_rmsd(env, actor_critic, obs):
    """Per-env tracking error between ``base_lin_vel[:, 0]`` and ``commands[:, 0]``.

    The difference is squared on its original device, moved to the CPU, and
    square-rooted there, giving one non-negative value per env.
    """
    error = env.base_lin_vel[:, 0] - env.commands[:, 0]
    squared = error.pow(2).cpu()
    return squared.pow(0.5)
def ang_vel_rmsd(env, actor_critic, obs):
    """Per-env tracking error between ``base_ang_vel[:, 2]`` and ``commands[:, 2]``.

    Squares the difference on its original device, moves it to the CPU and
    takes the square root there (one non-negative value per env).
    """
    measured, commanded = env.base_ang_vel[:, 2], env.commands[:, 2]
    return (measured - commanded).pow(2).cpu().pow(0.5)
def lin_vel_x(env, actor_critic, obs):
    """Return column 0 of ``env.base_lin_vel`` (one value per env) on the CPU."""
    first_component = env.base_lin_vel[:, 0]
    return first_component.cpu()
def ang_vel_yaw(env, actor_critic, obs):
    """Return column 2 of ``env.base_ang_vel`` (one value per env) on the CPU."""
    third_component = env.base_ang_vel[:, 2]
    return third_component.cpu()
def base_height(env, actor_critic, obs):
    """Mean of the base z-coordinate minus each measured height, per env.

    ``env.root_states[:, 2]`` gets a trailing axis so it broadcasts against
    the rows of ``env.measured_heights``; the differences are averaged along
    dim 1 and returned on the CPU.
    """
    base_z = env.root_states[:, 2].unsqueeze(1)
    clearance = base_z - env.measured_heights
    return clearance.mean(dim=1).cpu()
def max_torques(env, actor_critic, obs):
    """Largest absolute value in each row of ``env.torques``, on the CPU."""
    import torch
    magnitudes = torch.abs(env.torques)
    # Only the values are needed; the argmax indices are discarded.
    return magnitudes.max(dim=1).values.cpu()
def power_consumption(env, actor_critic, obs):
    """Row-wise sum of ``env.torques * env.dof_vel``, returned on the CPU.

    The product is signed: joints whose torque opposes their velocity
    contribute negatively to the total.
    """
    joint_power = env.torques * env.dof_vel
    return joint_power.sum(dim=1).cpu()
def CoT(env, actor_critic, obs):
    """Cost of transport P / (m * g * v), one value per env.

    P comes from ``power_consumption`` (signed sum of torque * joint velocity),
    m is ``env.default_body_mass + env.payloads``, g is 9.8 m/s^2 and v is the
    norm of the first two components of ``env.base_lin_vel``. Every term is
    moved to the CPU. There is no guard for v == 0, so stationary envs give
    inf/nan.
    """
    import torch
    power = power_consumption(env, actor_critic, obs)
    mass = (env.default_body_mass + env.payloads).cpu()
    gravity = 9.8  # m/s^2
    planar_speed = torch.norm(env.base_lin_vel[:, 0:2], dim=1).cpu()
    denominator = mass * gravity * planar_speed
    return power / denominator
def froude_number(env, actor_critic, obs):
    """Froude number v^2 / (g * h), with v taken from ``lin_vel_x``.

    g is 9.8 and the characteristic height h is hard-coded to 0.30.
    """
    gravity, height = 9.8, 0.30
    speed = lin_vel_x(env, actor_critic, obs)
    return speed ** 2 / (gravity * height)
def adaptation_loss(env, actor_critic, obs):
    """Per-row mean squared error between adaptation-module and encoder outputs.

    Compares ``actor_critic.adaptation_module(obs["obs_history"])`` with
    ``actor_critic.env_factor_encoder(obs["privileged_obs"])`` and averages the
    squared difference over dim 1.

    Returns:
        A CPU tensor with one value per row, or ``None`` when ``actor_critic``
        has no ``adaptation_module`` attribute.
    """
    import torch
    if not hasattr(actor_critic, "adaptation_module"):
        return None
    # Metric evaluation only: avoid building an autograd graph for the two
    # forward passes, since the results are detached anyway.
    with torch.no_grad():
        pred = actor_critic.adaptation_module(obs["obs_history"])
        target = actor_critic.env_factor_encoder(obs["privileged_obs"])
    return torch.mean((pred.cpu().detach() - target.cpu().detach()) ** 2, dim=1)
def auxiliary_rewards(env, actor_critic, obs):
    """Evaluate every reward term of ``env`` and return the scaled values by name.

    ``env.reward_functions`` and ``env.reward_names`` are paired by position;
    each function is called with no arguments and multiplied by
    ``env.reward_scales[name]``.

    Returns:
        dict mapping each reward name to a detached CPU tensor.

    Raises:
        IndexError: if ``env.reward_names`` is shorter than
            ``env.reward_functions``.
    """
    rewards = {}
    # enumerate (not zip) so a too-short name list still fails loudly.
    for i, reward_fn in enumerate(env.reward_functions):
        name = env.reward_names[i]
        rewards[name] = (reward_fn() * env.reward_scales[name]).cpu().detach()
    return rewards
def termination(env, actor_critic, obs):
    """Return ``env.reset_buf`` moved to the CPU and detached."""
    reset_flags = env.reset_buf
    return reset_flags.cpu().detach()
def privileged_obs(env, actor_critic, obs):
    """Return ``obs["privileged_obs"]`` as a NumPy array (via the CPU)."""
    privileged = obs["privileged_obs"]
    return privileged.cpu().numpy()
def latents(env, actor_critic, obs):
    """Encode ``obs["privileged_obs"]`` with ``actor_critic.env_factor_encoder``.

    Returns:
        The encoder output as a NumPy array.
    """
    encoded = actor_critic.env_factor_encoder(obs["privileged_obs"])
    # detach(): the output of a module with trainable parameters requires grad
    # when grad mode is on, and Tensor.numpy() raises on such tensors.
    return encoded.detach().cpu().numpy()
# Registry of metric functions keyed by name. At module scope locals() is the
# module namespace (and the comprehension's first iterable is evaluated in
# that scope), so this is a snapshot of every name bound ABOVE this line,
# minus `to_numpy` and any name containing a double underscore (e.g.
# __name__, __doc__). Any other module-level name bound before this line --
# an import or a constant -- would be registered as a "metric" too.
# NOTE(review): presumably why the functions above import torch locally.
# Names bound after this line are not included.
METRICS_FNS = {name: fn for name, fn in locals().items() if name not in ['to_numpy'] and "__" not in name}
if __name__ == '__main__':
    # Smoke test: list the registered metrics, then evaluate lin_vel_rmsd on
    # a stand-in env carrying random velocities and commands.
    for registered in METRICS_FNS.items():
        print(registered)

    import torch

    fake_env = type("FakeEnv", (), {})()  # bare attribute holder
    fake_env.base_lin_vel = torch.rand(10, 3)
    fake_env.commands = torch.rand(10, 3)
    print(lin_vel_rmsd(fake_env, None, None))