From edde9b8f8a4325b3a6cb882305f02a01ffba2d7c Mon Sep 17 00:00:00 2001
From: Jerry Xu
Date: Sun, 2 Jun 2024 16:05:57 -0400
Subject: [PATCH] velocity in ppo

---
 .../legged_gym/envs/a1/a1_leap_config.py   |  7 +++++--
 rsl_rl/rsl_rl/algorithms/ppo.py            | 17 ++++++++++-------
 rsl_rl/rsl_rl/runners/on_policy_runner.py  | 19 +++++++++++++++----
 3 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/legged_gym/legged_gym/envs/a1/a1_leap_config.py b/legged_gym/legged_gym/envs/a1/a1_leap_config.py
index ae6bd5b..f7e8532 100644
--- a/legged_gym/legged_gym/envs/a1/a1_leap_config.py
+++ b/legged_gym/legged_gym/envs/a1/a1_leap_config.py
@@ -30,7 +30,6 @@ class A1LeapCfg( A1FieldCfg ):
             virtual_terrain= True, # Change this to False for real terrain
             no_perlin_threshold= 0.06,
             n_obstacles_curriculum = True,
-            n_obstacles_per_track=2,
         ))
 
         TerrainPerlin_kwargs = merge_dict(A1FieldCfg.terrain.TerrainPerlin_kwargs, dict(
@@ -82,6 +81,9 @@ class A1LeapCfg( A1FieldCfg ):
 
 
 class A1LeapCfgPPO( A1FieldCfgPPO ):
+    class policy(A1FieldCfgPPO.policy):
+        num_critic_obs = 78
+
     class algorithm( A1FieldCfgPPO.algorithm ):
         entropy_coef = 0.0
         clip_min_std = 0.2
@@ -104,7 +106,8 @@ class A1LeapCfgPPO( A1FieldCfgPPO ):
         resume = True
         # load_run = "{Your traind walking model directory}"
         # load_run = "May16_18-12-08_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
-        load_run = "Leap_loss_success"
+        # load_run = "Leap_loss_success"
+        load_run = 'High_speed_walk'
         # load_run = "May15_21-34-27_Skillleap_pEnergySubsteps-1e-06_virtual"#"May15_17-07-38_WalkingBase_pEnergySubsteps2e-5_aScale0.5"
         # load_run = "{Your virtually trained leap model directory}"
         max_iterations = 20000
diff --git a/rsl_rl/rsl_rl/algorithms/ppo.py b/rsl_rl/rsl_rl/algorithms/ppo.py
index aa070a2..ddb3592 100644
--- a/rsl_rl/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/rsl_rl/algorithms/ppo.py
@@ -115,14 +115,14 @@ class PPO:
         # if self.lin_vel_x is not None:
         #     velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
         self.transition.actions = self.actor_critic.act(obs, velocity=velocity * self.command_scale)[0].detach()
-        critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
-        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
+        # critic_obs[..., 9] = velocity.squeeze(-1) * self.command_scale
+        self.transition.values = self.actor_critic.evaluate(vel_obs).detach()
         self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
         self.transition.action_mean = self.actor_critic.action_mean.detach()
         self.transition.action_sigma = self.actor_critic.action_std.detach()
         # need to record obs and critic_obs before env.step()
-        self.transition.observations = obs
-        self.transition.critic_observations = critic_obs
+        self.transition.observations = vel_obs
+        self.transition.critic_observations = vel_obs
         return self.transition.actions, velocity.squeeze().detach()
 
     def process_env_step(self, rewards, dones, infos):
@@ -138,6 +138,7 @@ class PPO:
         self.actor_critic.reset(dones)
 
     def compute_returns(self, last_critic_obs):
+        last_critic_obs = torch.cat([last_critic_obs[..., :9], last_critic_obs[..., 12:]], dim=-1)
         last_values= self.actor_critic.evaluate(last_critic_obs).detach()
         self.storage.compute_returns(last_values, self.gamma, self.lam)
 
@@ -183,12 +184,14 @@ class PPO:
 
 
     def compute_losses(self, minibatch, current_learning_iteration=None):
         obs = copy.deepcopy(minibatch.obs)
-        vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
+        # vel_obs = torch.cat([obs[..., :9], obs[..., 12:]], dim=-1)
-        velocity = self.velocity_planner(vel_obs)
+        velocity = self.velocity_planner(obs)
+        zeros = torch.zeros_like(velocity, requires_grad=False)
+        obs_with_cmd = torch.cat([obs[..., :9], velocity, zeros, zeros, obs[..., 9:]], dim=-1)
         # if self.lin_vel_x is not None:
         #     velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
-        self.actor_critic.act(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
+        self.actor_critic.act(obs_with_cmd, masks=minibatch.masks, hidden_states=minibatch.hid_states[0], velocity=velocity * self.command_scale)
         actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
         value_batch = self.actor_critic.evaluate(obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[1])
         mu_batch = self.actor_critic.action_mean
diff --git a/rsl_rl/rsl_rl/runners/on_policy_runner.py b/rsl_rl/rsl_rl/runners/on_policy_runner.py
index 99c6ba4..e568ebf 100644
--- a/rsl_rl/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/rsl_rl/runners/on_policy_runner.py
@@ -78,7 +78,7 @@ class OnPolicyRunner:
         self.save_interval = self.cfg["save_interval"]
 
         # init storage and model
-        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])
+        self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs - 3], [self.env.num_privileged_obs], [self.env.num_actions])
 
         # Log
         self.log_dir = log_dir
@@ -251,11 +251,22 @@ class OnPolicyRunner:
 
     def load(self, path, load_optimizer=True):
         loaded_dict = torch.load(path)
-        self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
+
+        try:
+            self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'], strict=False)
+            if load_optimizer and "optimizer_state_dict" in loaded_dict:
+                self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
+        except:
+            from collections import OrderedDict
+            new_state = OrderedDict()
+            for k, v in loaded_dict['model_state_dict'].items():
+                if "memory_c" in k or "critic" in k:
+                    continue
+                new_state[k] = v
+            self.alg.actor_critic.load_state_dict(new_state, strict=False)
+
         if 'velocity_planner_state_dict' in loaded_dict:
             self.alg.velocity_planner.load_state_dict(loaded_dict['velocity_planner_state_dict'])
-        if load_optimizer and "optimizer_state_dict" in loaded_dict:
-            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'], )
         if load_optimizer and "velocity_optimizer_state_dict" in loaded_dict:
             self.alg.velocity_optimizer.load_state_dict(loaded_dict['velocity_optimizer_state_dict'], )
         if "lr_scheduler_state_dict" in loaded_dict: