Copied test file to extend and implement training

This commit is contained in:
Nimesh Khandelwal 2024-02-15 21:31:32 -05:00
parent 0bd527f72b
commit fae38c6eb9
1 changed file with 5 additions and 50 deletions

View File

@@ -5,8 +5,6 @@ import matplotlib.pyplot as plt
import torch import torch
from tqdm import trange from tqdm import trange
from ml_logger import logger
from Go2Py.sim.gym.envs import * from Go2Py.sim.gym.envs import *
from Go2Py.sim.gym.envs.base.legged_robot_config import Cfg from Go2Py.sim.gym.envs.base.legged_robot_config import Cfg
from Go2Py.sim.gym.envs.go2.go2_config import config_go2 from Go2Py.sim.gym.envs.go2.go2_config import config_go2
@@ -197,55 +195,12 @@ def run_env(render=False, headless=False):
print("Show the first frame and exit.") print("Show the first frame and exit.")
exit() exit()
# log the experiment parameters for i in trange(1000, desc="Running"):
logger.log_params(AC_Args=vars(AC_Args), PPO_Args=vars(PPO_Args), RunnerArgs=vars(RunnerArgs), actions = 0. * torch.ones(env.num_envs, env.num_actions, device=env.device)
Cfg=vars(Cfg)) obs, rew, done, info = env.step(actions)
# breakpoint()
env = HistoryWrapper(env) print("Done")
gpu_id = 2
runner = Runner(env, device=f"cuda:{gpu_id}")
runner.learn(num_learning_iterations=10000, init_at_random_ep_len=True, eval_freq=100)
# for i in trange(1000, desc="Running"):
# actions = 0. * torch.ones(env.num_envs, env.num_actions, device=env.device)
# obs, rew, done, info = env.step(actions)
# # breakpoint()
# print("Done")
if __name__ == '__main__': if __name__ == '__main__':
from pathlib import Path
from ml_logger import logger
from go2_gym import MINI_GYM_ROOT_DIR
stem = Path(__file__).stem
logger.configure(logger.utcnow(f'gait-conditioned-agility/%Y-%m-%d/{stem}/%H%M%S.%f'),
root=Path(f"{MINI_GYM_ROOT_DIR}/runs").resolve(), )
logger.log_text("""
charts:
- yKey: train/episode/rew_total/mean
xKey: iterations
- yKey: train/episode/rew_tracking_lin_vel/mean
xKey: iterations
- yKey: train/episode/rew_tracking_contacts_shaped_force/mean
xKey: iterations
- yKey: train/episode/rew_action_smoothness_1/mean
xKey: iterations
- yKey: train/episode/rew_action_smoothness_2/mean
xKey: iterations
- yKey: train/episode/rew_tracking_contacts_shaped_vel/mean
xKey: iterations
- yKey: train/episode/rew_orientation_control/mean
xKey: iterations
- yKey: train/episode/rew_dof_pos/mean
xKey: iterations
- yKey: train/episode/command_area_trot/mean
xKey: iterations
- yKey: train/episode/max_terrain_height/mean
xKey: iterations
- type: video
glob: "videos/*.mp4"
- yKey: adaptation_loss/mean
xKey: iterations
""", filename=".charts.yml", dedent=True)
run_env(render=True, headless=False) run_env(render=True, headless=False)