walk-these-ways-go2/go1_gym_learn/ppo/__init__.py

# License: see [LICENSE, LICENSES/rsl_rl/LICENSE]
import time
from collections import deque
import torch
from ml_logger import logger
from params_proto import PrefixProto
import os
import copy
from .actor_critic import ActorCritic
from .rollout_storage import RolloutStorage
from go1_gym import MINI_GYM_ROOT_DIR
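

# Recursively convert a (possibly nested) config object into a plain dict,
# skipping private attributes and the `terrain` entry.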
def class_to_dict(obj) -> dict:
    if not hasattr(obj, "__dict__"):
        return obj
    result = {}
    for key in dir(obj):
        if key.startswith("_") or key == "terrain":
            continue
        element = []
        val = getattr(obj, key)
        if isinstance(val, list):
            for item in val:
                element.append(class_to_dict(item))
        else:
            element = class_to_dict(val)
        result[key] = element
    return result
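

# Bundles the curriculum metric caches (per-bin "slot" statistics and
# distribution statistics) that are populated during rollouts and summarized
# into `curriculum/info.pkl`.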
class DataCaches:
    def __init__(self, curriculum_bins):
        from go1_gym_learn.ppo.metrics_caches import DistCache, SlotCache

        self.slot_cache = SlotCache(curriculum_bins)
        self.dist_cache = DistCache()


caches = DataCaches(1)
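

# Training-loop configuration for the Runner, declared as a params_proto
# PrefixProto so the values can be overridden by launch scripts.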
class RunnerArgs(PrefixProto, cli=False):
    # runner
    algorithm_class_name = 'PPO'
    num_steps_per_env = 24  # per iteration
    max_iterations = 1500  # number of policy updates

    # logging
    save_interval = 400  # check for potential saves every this many iterations
    save_video_interval = 100
    log_freq = 10

    # load and resume
    resume = False
    load_run = -1  # -1 = last run
    checkpoint = -1  # -1 = last saved model
    resume_path = None  # updated from load_run and chkpt
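

# On-policy runner: owns the ActorCritic model and the PPO algorithm, collects
# rollouts from the train/eval environments, runs the PPO updates, and handles
# metric logging, video capture, and checkpoint export.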
class Runner:
    def __init__(self, env, device='cpu'):
        from .ppo import PPO

        self.device = device
        self.env = env

        actor_critic = ActorCritic(self.env.num_obs,
                                   self.env.num_privileged_obs,
                                   self.env.num_obs_history,
                                   self.env.num_actions,
                                   ).to(self.device)
        self.alg = PPO(actor_critic, device=self.device)
        self.num_steps_per_env = RunnerArgs.num_steps_per_env

        # init storage and model
        self.alg.init_storage(self.env.num_train_envs, self.num_steps_per_env, [self.env.num_obs],
                              [self.env.num_privileged_obs], [self.env.num_obs_history], [self.env.num_actions])

        self.tot_timesteps = 0
        self.tot_time = 0
        self.current_learning_iteration = 0
        self.last_recording_it = 0

        self.env.reset()
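
    # Main training entry point: alternates rollout collection with PPO updates
    # for `num_learning_iterations` iterations, logging through ml_logger.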
    def learn(self, num_learning_iterations, init_at_random_ep_len=False, eval_freq=100, eval_expert=False):
        from ml_logger import logger

        # initialize writer
        assert logger.prefix, "you will overwrite the entire instrument server"

        logger.start('start', 'epoch', 'episode', 'run', 'step')

        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf,
                                                             high=int(self.env.max_episode_length))

        # split train and test envs
        num_train_envs = self.env.num_train_envs

        obs_dict = self.env.get_observations()
        obs, privileged_obs, obs_history = obs_dict["obs"], obs_dict["privileged_obs"], obs_dict["obs_history"]
        obs, privileged_obs, obs_history = obs.to(self.device), privileged_obs.to(self.device), obs_history.to(
            self.device)
        self.alg.actor_critic.train()

        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        rewbuffer_eval = deque(maxlen=100)
        lenbuffer_eval = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        if hasattr(self.env, "curriculum"):
            caches.__init__(curriculum_bins=len(self.env.curriculum))

        tot_iter = self.current_learning_iteration + num_learning_iterations
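
        # Main training loop: collect a rollout, compute returns, run the PPO
        # update, then log metrics and periodically save checkpoints and videos.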
        for it in range(self.current_learning_iteration, tot_iter):
            start = time.time()

            # Rollout
            with torch.inference_mode():
                for i in range(self.num_steps_per_env):
                    actions_train = self.alg.act(obs[:num_train_envs], privileged_obs[:num_train_envs],
                                                 obs_history[:num_train_envs])
                    if eval_expert:
                        actions_eval = self.alg.actor_critic.act_teacher(obs[num_train_envs:],
                                                                         privileged_obs[num_train_envs:])
                    else:
                        actions_eval = self.alg.actor_critic.act_student(obs[num_train_envs:],
                                                                         obs_history[num_train_envs:])
                    ret = self.env.step(torch.cat((actions_train, actions_eval), dim=0))
                    obs_dict, rewards, dones, infos = ret
                    obs, privileged_obs, obs_history = obs_dict["obs"], obs_dict["privileged_obs"], obs_dict[
                        "obs_history"]
                    obs, privileged_obs, obs_history, rewards, dones = obs.to(self.device), privileged_obs.to(
                        self.device), obs_history.to(self.device), rewards.to(self.device), dones.to(self.device)
                    self.alg.process_env_step(rewards[:num_train_envs], dones[:num_train_envs], infos)

                    if 'train/episode' in infos:
                        with logger.Prefix(metrics="train/episode"):
                            logger.store_metrics(**infos['train/episode'])

                    if 'eval/episode' in infos:
                        with logger.Prefix(metrics="eval/episode"):
                            logger.store_metrics(**infos['eval/episode'])

                    if 'curriculum' in infos:
                        curr_bins_train = infos['curriculum']['reset_train_env_bins']
                        curr_bins_eval = infos['curriculum']['reset_eval_env_bins']

                        caches.slot_cache.log(curr_bins_train, **{
                            k.split("/", 1)[-1]: v for k, v in infos['curriculum'].items()
                            if k.startswith('slot/train')
                        })
                        caches.slot_cache.log(curr_bins_eval, **{
                            k.split("/", 1)[-1]: v for k, v in infos['curriculum'].items()
                            if k.startswith('slot/eval')
                        })
                        caches.dist_cache.log(**{
                            k.split("/", 1)[-1]: v for k, v in infos['curriculum'].items()
                            if k.startswith('dist/train')
                        })
                        caches.dist_cache.log(**{
                            k.split("/", 1)[-1]: v for k, v in infos['curriculum'].items()
                            if k.startswith('dist/eval')
                        })
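
                    # Per-env episode bookkeeping: accumulate returns and lengths,
                    # and when an env finishes an episode, push its totals into the
                    # train or eval buffers and reset its counters.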
                    cur_reward_sum += rewards
                    cur_episode_length += 1

                    new_ids = (dones > 0).nonzero(as_tuple=False)

                    new_ids_train = new_ids[new_ids < num_train_envs]
                    rewbuffer.extend(cur_reward_sum[new_ids_train].cpu().numpy().tolist())
                    lenbuffer.extend(cur_episode_length[new_ids_train].cpu().numpy().tolist())
                    cur_reward_sum[new_ids_train] = 0
                    cur_episode_length[new_ids_train] = 0

                    new_ids_eval = new_ids[new_ids >= num_train_envs]
                    rewbuffer_eval.extend(cur_reward_sum[new_ids_eval].cpu().numpy().tolist())
                    lenbuffer_eval.extend(cur_episode_length[new_ids_eval].cpu().numpy().tolist())
                    cur_reward_sum[new_ids_eval] = 0
                    cur_episode_length[new_ids_eval] = 0

                # Learning step
                self.alg.compute_returns(obs[:num_train_envs], privileged_obs[:num_train_envs])

            if it % eval_freq == 0:
                self.env.reset_evaluation_envs()

            if it % eval_freq == 0:
                logger.save_pkl({"iteration": it,
                                 **caches.slot_cache.get_summary(),
                                 **caches.dist_cache.get_summary()},
                                path=f"curriculum/info.pkl", append=True)
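
            # Gradient update on the collected rollout (PPO value/surrogate losses
            # plus the adaptation-module loss).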
            mean_value_loss, mean_surrogate_loss, mean_adaptation_module_loss = self.alg.update()

            logger.store_metrics(
                time_elapsed=logger.since('start'),
                time_iter=logger.split('epoch'),
                adaptation_loss=mean_adaptation_module_loss,
                mean_value_loss=mean_value_loss,
                mean_surrogate_loss=mean_surrogate_loss
            )

            if RunnerArgs.save_video_interval:
                self.log_video(it)

            self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
            if logger.every(RunnerArgs.log_freq, "iteration", start_on=1):
                # if it % Config.log_freq == 0:
                logger.log_metrics_summary(key_values={"timesteps": self.tot_timesteps, "iterations": it})
                logger.job_running()
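
            # Periodic checkpointing: save the actor-critic weights and export the
            # adaptation module and actor body as TorchScript modules.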
            if it % RunnerArgs.save_interval == 0:
                with logger.Sync():
                    logger.torch_save(self.alg.actor_critic.state_dict(), f"checkpoints/ac_weights_{it:06d}.pt")
                    logger.duplicate(f"checkpoints/ac_weights_{it:06d}.pt", f"checkpoints/ac_weights_last.pt")

                    path = f'{MINI_GYM_ROOT_DIR}/tmp/legged_data'
                    os.makedirs(path, exist_ok=True)

                    adaptation_module_path = f'{path}/adaptation_module_latest.jit'
                    adaptation_module = copy.deepcopy(self.alg.actor_critic.adaptation_module).to('cpu')
                    traced_script_adaptation_module = torch.jit.script(adaptation_module)
                    traced_script_adaptation_module.save(adaptation_module_path)

                    body_path = f'{path}/body_latest.jit'
                    body_model = copy.deepcopy(self.alg.actor_critic.actor_body).to('cpu')
                    traced_script_body_module = torch.jit.script(body_model)
                    traced_script_body_module.save(body_path)

                    logger.upload_file(file_path=adaptation_module_path, target_path=f"checkpoints/", once=False)
                    logger.upload_file(file_path=body_path, target_path=f"checkpoints/", once=False)
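
        # Final save after the last iteration, mirroring the periodic checkpoint above.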
        self.current_learning_iteration += num_learning_iterations

        with logger.Sync():
            logger.torch_save(self.alg.actor_critic.state_dict(), f"checkpoints/ac_weights_{it:06d}.pt")
            logger.duplicate(f"checkpoints/ac_weights_{it:06d}.pt", f"checkpoints/ac_weights_last.pt")

            path = f'{MINI_GYM_ROOT_DIR}/tmp/legged_data'
            os.makedirs(path, exist_ok=True)

            adaptation_module_path = f'{path}/adaptation_module_latest.jit'
            adaptation_module = copy.deepcopy(self.alg.actor_critic.adaptation_module).to('cpu')
            traced_script_adaptation_module = torch.jit.script(adaptation_module)
            traced_script_adaptation_module.save(adaptation_module_path)

            body_path = f'{path}/body_latest.jit'
            body_model = copy.deepcopy(self.alg.actor_critic.actor_body).to('cpu')
            traced_script_body_module = torch.jit.script(body_model)
            traced_script_body_module.save(body_path)

            logger.upload_file(file_path=adaptation_module_path, target_path=f"checkpoints/", once=False)
            logger.upload_file(file_path=body_path, target_path=f"checkpoints/", once=False)
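
    # Start/stop environment video recording on a fixed interval and save any
    # completed frame sequences through the logger.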
    def log_video(self, it):
        if it - self.last_recording_it >= RunnerArgs.save_video_interval:
            self.env.start_recording()
            if self.env.num_eval_envs > 0:
                self.env.start_recording_eval()
            print("START RECORDING")
            self.last_recording_it = it

        frames = self.env.get_complete_frames()
        if len(frames) > 0:
            self.env.pause_recording()
            print("LOGGING VIDEO")
            logger.save_video(frames, f"videos/{it:05d}.mp4", fps=1 / self.env.dt)

        if self.env.num_eval_envs > 0:
            frames = self.env.get_complete_frames_eval()
            if len(frames) > 0:
                self.env.pause_recording_eval()
                print("LOGGING EVAL VIDEO")
                logger.save_video(frames, f"videos/{it:05d}_eval.mp4", fps=1 / self.env.dt)

    def get_inference_policy(self, device=None):
        self.alg.actor_critic.eval()
        if device is not None:
            self.alg.actor_critic.to(device)
        return self.alg.actor_critic.act_inference

    def get_expert_policy(self, device=None):
        self.alg.actor_critic.eval()
        if device is not None:
            self.alg.actor_critic.to(device)
        return self.alg.actor_critic.act_expert
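

# Usage sketch (illustrative only, kept as a comment; the environment factory
# below is a placeholder, not an API defined in this file). Any env exposing
# num_obs, num_privileged_obs, num_obs_history, num_actions, num_train_envs,
# num_envs and the step/reset/recording methods used above should work, and an
# ml_logger prefix must be configured beforehand since learn() asserts it.
#
#     from ml_logger import logger
#     logger.configure(...)  # set a run prefix/root for this training run
#
#     env = make_velocity_tracking_env(...)  # hypothetical env constructor
#     runner = Runner(env, device="cuda:0")
#     runner.learn(num_learning_iterations=RunnerArgs.max_iterations,
#                  init_at_random_ep_len=True)
#     policy = runner.get_inference_policy(device="cuda:0")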