# walk-these-ways-go2/go2_gym_learn/ppo_cse/__init__.py

import time
from collections import deque
import copy
import os

import torch
from ml_logger import logger
from params_proto import PrefixProto

from .actor_critic import ActorCritic
from .rollout_storage import RolloutStorage


def class_to_dict(obj) -> dict:
    """Recursively convert a config class (or instance) into a nested dict,
    skipping private attributes and the special "terrain" entry."""
    if not hasattr(obj, "__dict__"):
        return obj
    result = {}
    for key in dir(obj):
        if key.startswith("_") or key == "terrain":
            continue
        element = []
        val = getattr(obj, key)
        if isinstance(val, list):
            for item in val:
                element.append(class_to_dict(item))
        else:
            element = class_to_dict(val)
        result[key] = element
    return result
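
# A minimal usage sketch for class_to_dict (illustrative only; `Cfg` below is a
# hypothetical nested config class, not one defined in this module):
#
#     class Cfg:
#         class control:
#             stiffness = 20.0
#         decimation = 4
#
#     cfg_dict = class_to_dict(Cfg)
#     # -> {"control": {"stiffness": 20.0}, "decimation": 4}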


class DataCaches:
    def __init__(self, curriculum_bins):
        from go2_gym_learn.ppo.metrics_caches import SlotCache, DistCache

        self.slot_cache = SlotCache(curriculum_bins)
        self.dist_cache = DistCache()


caches = DataCaches(1)


class RunnerArgs(PrefixProto, cli=False):
    # runner
    algorithm_class_name = 'RMA'
    num_steps_per_env = 24  # per iteration
    max_iterations = 1500  # number of policy updates

    # logging
    save_interval = 400  # check for potential saves every this many iterations
    save_video_interval = 100
    log_freq = 10

    # load and resume
    resume = False
    load_run = -1  # -1 = last run
    checkpoint = -1  # -1 = last saved model
    resume_path = None  # updated from load_run and chkpt
    resume_curriculum = True
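
# Because RunnerArgs is a params_proto PrefixProto, its fields are plain class
# attributes and can be overridden before constructing the Runner. A minimal
# sketch (the resume_path value is a made-up placeholder):
#
#     RunnerArgs.resume = True
#     RunnerArgs.resume_path = "some-user/some-experiment/checkpoints"
#     RunnerArgs.save_video_interval = 0  # falsy value disables video logging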


class Runner:

    def __init__(self, env, device='cpu'):
        from .ppo import PPO

        self.device = device
        self.env = env

        actor_critic = ActorCritic(self.env.num_obs,
                                   self.env.num_privileged_obs,
                                   self.env.num_obs_history,
                                   self.env.num_actions,
                                   ).to(self.device)

        if RunnerArgs.resume:
            # load pretrained weights from resume_path
            from ml_logger import ML_Logger
            loader = ML_Logger(root="http://escher.csail.mit.edu:8080",
                               prefix=RunnerArgs.resume_path)
            weights = loader.load_torch("checkpoints/ac_weights_last.pt")
            actor_critic.load_state_dict(state_dict=weights)

            if hasattr(self.env, "curricula") and RunnerArgs.resume_curriculum:
                # load curriculum state
                distributions = loader.load_pkl("curriculum/distribution.pkl")
                distribution_last = distributions[-1]["distribution"]
                gait_names = [key[8:] if key.startswith("weights_") else None for key in distribution_last.keys()]
                for gait_id, gait_name in enumerate(self.env.category_names):
                    self.env.curricula[gait_id].weights = distribution_last[f"weights_{gait_name}"]
                    print(gait_name)

        self.alg = PPO(actor_critic, device=self.device)
        self.num_steps_per_env = RunnerArgs.num_steps_per_env

        # init storage and model
        self.alg.init_storage(self.env.num_train_envs, self.num_steps_per_env, [self.env.num_obs],
                              [self.env.num_privileged_obs], [self.env.num_obs_history], [self.env.num_actions])

        self.tot_timesteps = 0
        self.tot_time = 0
        self.current_learning_iteration = 0
        self.last_recording_it = 0

        self.env.reset()

    def learn(self, num_learning_iterations, init_at_random_ep_len=False, eval_freq=100, curriculum_dump_freq=500, eval_expert=False):
        from ml_logger import logger
        # initialize writer
        assert logger.prefix, "you will overwrite the entire instrument server"

        logger.start('start', 'epoch', 'episode', 'run', 'step')

        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf,
                                                             high=int(self.env.max_episode_length))

        # split train and test envs
        num_train_envs = self.env.num_train_envs

        obs_dict = self.env.get_observations()  # TODO: check, is this correct on the first step?
        obs, privileged_obs, obs_history = obs_dict["obs"], obs_dict["privileged_obs"], obs_dict["obs_history"]
        obs, privileged_obs, obs_history = obs.to(self.device), privileged_obs.to(self.device), obs_history.to(
            self.device)
        self.alg.actor_critic.train()  # switch to train mode (for dropout for example)

        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        rewbuffer_eval = deque(maxlen=100)
        lenbuffer_eval = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        tot_iter = self.current_learning_iteration + num_learning_iterations
        for it in range(self.current_learning_iteration, tot_iter):
            start = time.time()
            # Rollout: collect num_steps_per_env steps from every environment.
            # The first num_train_envs envs act with the training policy; the
            # remaining (eval) envs act with the teacher or student policy,
            # depending on eval_expert.
            with torch.inference_mode():
                for i in range(self.num_steps_per_env):
                    actions_train = self.alg.act(obs[:num_train_envs], privileged_obs[:num_train_envs],
                                                 obs_history[:num_train_envs])
                    if eval_expert:
                        actions_eval = self.alg.actor_critic.act_teacher(obs_history[num_train_envs:],
                                                                         privileged_obs[num_train_envs:])
                    else:
                        actions_eval = self.alg.actor_critic.act_student(obs_history[num_train_envs:])
                    ret = self.env.step(torch.cat((actions_train, actions_eval), dim=0))
                    obs_dict, rewards, dones, infos = ret
                    obs, privileged_obs, obs_history = obs_dict["obs"], obs_dict["privileged_obs"], obs_dict[
                        "obs_history"]
                    obs, privileged_obs, obs_history, rewards, dones = obs.to(self.device), privileged_obs.to(
                        self.device), obs_history.to(self.device), rewards.to(self.device), dones.to(self.device)
                    self.alg.process_env_step(rewards[:num_train_envs], dones[:num_train_envs], infos)

                    if 'train/episode' in infos:
                        with logger.Prefix(metrics="train/episode"):
                            logger.store_metrics(**infos['train/episode'])

                    if 'eval/episode' in infos:
                        with logger.Prefix(metrics="eval/episode"):
                            logger.store_metrics(**infos['eval/episode'])

                    if 'curriculum' in infos:
                        cur_reward_sum += rewards
                        cur_episode_length += 1

                        new_ids = (dones > 0).nonzero(as_tuple=False)

                        new_ids_train = new_ids[new_ids < num_train_envs]
                        rewbuffer.extend(cur_reward_sum[new_ids_train].cpu().numpy().tolist())
                        lenbuffer.extend(cur_episode_length[new_ids_train].cpu().numpy().tolist())
                        cur_reward_sum[new_ids_train] = 0
                        cur_episode_length[new_ids_train] = 0

                        new_ids_eval = new_ids[new_ids >= num_train_envs]
                        rewbuffer_eval.extend(cur_reward_sum[new_ids_eval].cpu().numpy().tolist())
                        lenbuffer_eval.extend(cur_episode_length[new_ids_eval].cpu().numpy().tolist())
                        cur_reward_sum[new_ids_eval] = 0
                        cur_episode_length[new_ids_eval] = 0

                    if 'curriculum/distribution' in infos:
                        distribution = infos['curriculum/distribution']

                stop = time.time()
                collection_time = stop - start

                # Learning step
                start = stop
                self.alg.compute_returns(obs_history[:num_train_envs], privileged_obs[:num_train_envs])

            if it % curriculum_dump_freq == 0:
                logger.save_pkl({"iteration": it,
                                 **caches.slot_cache.get_summary(),
                                 **caches.dist_cache.get_summary()},
                                path="curriculum/info.pkl", append=True)

                if 'curriculum/distribution' in infos:
                    logger.save_pkl({"iteration": it,
                                     "distribution": distribution},
                                    path="curriculum/distribution.pkl", append=True)

            # PPO update; returns the mean losses accumulated over this iteration's epochs
            (mean_value_loss, mean_surrogate_loss, mean_adaptation_module_loss,
             mean_decoder_loss, mean_decoder_loss_student, mean_adaptation_module_test_loss,
             mean_decoder_test_loss, mean_decoder_test_loss_student) = self.alg.update()
            stop = time.time()
            learn_time = stop - start

            logger.store_metrics(
                # total_time=learn_time - collection_time,
                time_elapsed=logger.since('start'),
                time_iter=logger.split('epoch'),
                adaptation_loss=mean_adaptation_module_loss,
                mean_value_loss=mean_value_loss,
                mean_surrogate_loss=mean_surrogate_loss,
                mean_decoder_loss=mean_decoder_loss,
                mean_decoder_loss_student=mean_decoder_loss_student,
                mean_decoder_test_loss=mean_decoder_test_loss,
                mean_decoder_test_loss_student=mean_decoder_test_loss_student,
                mean_adaptation_module_test_loss=mean_adaptation_module_test_loss
            )

            if RunnerArgs.save_video_interval:
                self.log_video(it)

            self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
            if logger.every(RunnerArgs.log_freq, "iteration", start_on=1):
                # if it % Config.log_freq == 0:
                logger.log_metrics_summary(key_values={"timesteps": self.tot_timesteps, "iterations": it})
                logger.job_running()

            if it % RunnerArgs.save_interval == 0:
                with logger.Sync():
                    logger.torch_save(self.alg.actor_critic.state_dict(), f"checkpoints/ac_weights_{it:06d}.pt")
                    logger.duplicate(f"checkpoints/ac_weights_{it:06d}.pt", f"checkpoints/ac_weights_last.pt")

                    # export the adaptation module and actor body as TorchScript for deployment
                    path = './tmp/legged_data'
                    os.makedirs(path, exist_ok=True)

                    adaptation_module_path = f'{path}/adaptation_module_latest.jit'
                    adaptation_module = copy.deepcopy(self.alg.actor_critic.adaptation_module).to('cpu')
                    traced_script_adaptation_module = torch.jit.script(adaptation_module)
                    traced_script_adaptation_module.save(adaptation_module_path)

                    body_path = f'{path}/body_latest.jit'
                    body_model = copy.deepcopy(self.alg.actor_critic.actor_body).to('cpu')
                    traced_script_body_module = torch.jit.script(body_model)
                    traced_script_body_module.save(body_path)

                    logger.upload_file(file_path=adaptation_module_path, target_path="checkpoints/", once=False)
                    logger.upload_file(file_path=body_path, target_path="checkpoints/", once=False)

        self.current_learning_iteration += num_learning_iterations

        # final save after training completes
        with logger.Sync():
            logger.torch_save(self.alg.actor_critic.state_dict(), f"checkpoints/ac_weights_{it:06d}.pt")
            logger.duplicate(f"checkpoints/ac_weights_{it:06d}.pt", f"checkpoints/ac_weights_last.pt")

            path = './tmp/legged_data'
            os.makedirs(path, exist_ok=True)

            adaptation_module_path = f'{path}/adaptation_module_latest.jit'
            adaptation_module = copy.deepcopy(self.alg.actor_critic.adaptation_module).to('cpu')
            traced_script_adaptation_module = torch.jit.script(adaptation_module)
            traced_script_adaptation_module.save(adaptation_module_path)

            body_path = f'{path}/body_latest.jit'
            body_model = copy.deepcopy(self.alg.actor_critic.actor_body).to('cpu')
            traced_script_body_module = torch.jit.script(body_model)
            traced_script_body_module.save(body_path)

            logger.upload_file(file_path=adaptation_module_path, target_path="checkpoints/", once=False)
            logger.upload_file(file_path=body_path, target_path="checkpoints/", once=False)
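
    # The exported TorchScript files are intended for deployment. A minimal
    # loading sketch (not part of the training pipeline; the input layout of the
    # body module is assumed to match ActorCritic's student path, i.e. the
    # observation history concatenated with the adaptation-module latent):
    #
    #     adaptation_module = torch.jit.load('./tmp/legged_data/adaptation_module_latest.jit')
    #     body = torch.jit.load('./tmp/legged_data/body_latest.jit')
    #     latent = adaptation_module(obs_history)
    #     action = body(torch.cat((obs_history, latent), dim=-1))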

    def log_video(self, it):
        if it - self.last_recording_it >= RunnerArgs.save_video_interval:
            self.env.start_recording()
            if self.env.num_eval_envs > 0:
                self.env.start_recording_eval()
            print("START RECORDING")
            self.last_recording_it = it

        frames = self.env.get_complete_frames()
        if len(frames) > 0:
            self.env.pause_recording()
            print("LOGGING VIDEO")
            logger.save_video(frames, f"videos/{it:05d}.mp4", fps=1 / self.env.dt)

        if self.env.num_eval_envs > 0:
            frames = self.env.get_complete_frames_eval()
            if len(frames) > 0:
                self.env.pause_recording_eval()
                print("LOGGING EVAL VIDEO")
                logger.save_video(frames, f"videos/{it:05d}_eval.mp4", fps=1 / self.env.dt)

    def get_inference_policy(self, device=None):
        self.alg.actor_critic.eval()  # switch to evaluation mode (dropout for example)
        if device is not None:
            self.alg.actor_critic.to(device)
        return self.alg.actor_critic.act_inference

    def get_expert_policy(self, device=None):
        self.alg.actor_critic.eval()  # switch to evaluation mode (dropout for example)
        if device is not None:
            self.alg.actor_critic.to(device)
        return self.alg.actor_critic.act_expert
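

# A minimal usage sketch (illustrative; the environment construction is an
# assumption, not defined in this module):
#
#     env = ...  # an IsaacGym-based env exposing num_obs, num_privileged_obs,
#                # num_obs_history, num_actions, num_train_envs, num_eval_envs, ...
#     runner = Runner(env, device="cuda:0")
#     runner.learn(num_learning_iterations=RunnerArgs.max_iterations,
#                  init_at_random_ep_len=True, eval_expert=False)
#     policy = runner.get_inference_policy(device="cuda:0")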