From 68f58434955bcfd6fe1e72b225164cc8d075ccf7 Mon Sep 17 00:00:00 2001
From: Jerry Xu
Date: Wed, 22 May 2024 16:35:46 -0400
Subject: [PATCH] velocity prediction added

---
 legged_gym/legged_gym/envs/a1/a1_leap_config.py |  3 ++-
 legged_gym/legged_gym/envs/base/legged_robot.py |  1 +
 legged_gym/legged_gym/scripts/play.py           |  9 +++++----
 rsl_rl/rsl_rl/modules/__init__.py               |  2 ++
 rsl_rl/rsl_rl/modules/actor_critic_recurrent.py | 15 +++++++++++++++
 5 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/legged_gym/legged_gym/envs/a1/a1_leap_config.py b/legged_gym/legged_gym/envs/a1/a1_leap_config.py
index 2aec9a7..23b12a6 100644
--- a/legged_gym/legged_gym/envs/a1/a1_leap_config.py
+++ b/legged_gym/legged_gym/envs/a1/a1_leap_config.py
@@ -60,13 +60,14 @@ class A1LeapCfg( A1FieldCfg ):
     class rewards( A1FieldCfg.rewards ):
         class scales:
             tracking_ang_vel = 0.05
-            world_vel_l2norm = -1.
+            # world_vel_l2norm = -1.
             legs_energy_substeps = -1e-6
             alive = 2.
             penetrate_depth = -4e-3
             penetrate_volume = -4e-3
             exceed_dof_pos_limits = -1e-1
             exceed_torque_limits_i = -2e-1
+            # track_predict_vel_l2norm = -1.
 
     class curriculum( A1FieldCfg.curriculum ):
         penetrate_volume_threshold_harder = 9000
diff --git a/legged_gym/legged_gym/envs/base/legged_robot.py b/legged_gym/legged_gym/envs/base/legged_robot.py
index 9df3134..488f38d 100644
--- a/legged_gym/legged_gym/envs/base/legged_robot.py
+++ b/legged_gym/legged_gym/envs/base/legged_robot.py
@@ -335,6 +335,7 @@ class LeggedRobot(BaseTask):
         Args:
             env_ids (List[int]): Environments ids for which new commands are needed
         """
+        # print(self.command_ranges["lin_vel_x"][0], self.command_ranges["lin_vel_x"][1])
         self.commands[env_ids, 0] = torch_rand_float(self.command_ranges["lin_vel_x"][0], self.command_ranges["lin_vel_x"][1], (len(env_ids), 1), device=self.device).squeeze(1)
         self.commands[env_ids, 1] = torch_rand_float(self.command_ranges["lin_vel_y"][0], self.command_ranges["lin_vel_y"][1], (len(env_ids), 1), device=self.device).squeeze(1)
         if self.cfg.commands.heading_command:
diff --git a/legged_gym/legged_gym/scripts/play.py b/legged_gym/legged_gym/scripts/play.py
index a11de64..605230c 100644
--- a/legged_gym/legged_gym/scripts/play.py
+++ b/legged_gym/legged_gym/scripts/play.py
@@ -86,7 +86,7 @@ def play(args):
     # override some parameters for testing
     if env_cfg.terrain.selected == "BarrierTrack":
         env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
-        env_cfg.env.episode_length_s = 20
+        env_cfg.env.episode_length_s = 5#20
         env_cfg.terrain.max_init_terrain_level = 0
         env_cfg.terrain.num_rows = 1
         env_cfg.terrain.num_cols = 1
@@ -102,7 +102,7 @@
         env_cfg.terrain.BarrierTrack_kwargs["options"] = [
             # "crawl",
             # "jump",
-            "leap",
+            # "leap",
             # "tilt",
         ]
         env_cfg.terrain.BarrierTrack_kwargs["leap"] = dict(
@@ -113,8 +113,8 @@
         if "one_obstacle_per_track" in env_cfg.terrain.BarrierTrack_kwargs.keys():
             env_cfg.terrain.BarrierTrack_kwargs.pop("one_obstacle_per_track")
-        env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 2# 2
-        env_cfg.commands.ranges.lin_vel_x = [2.0, 2.0] # [1.2, 1.2]
+        env_cfg.terrain.BarrierTrack_kwargs["n_obstacles_per_track"] = 0# 2
+        env_cfg.commands.ranges.lin_vel_x = [3.0, 3.0] # [1.2, 1.2]
         env_cfg.terrain.BarrierTrack_kwargs['track_block_length']= 3.
 
     if "distill" in args.task:
         env_cfg.commands.ranges.lin_vel_x = [0.0, 0.0]
@@ -125,6 +125,7 @@
             x= [0.6, 0.6],
             y= [-0.05, 0.05],
         )
+        env_cfg.commands.curriculum = False
         env_cfg.termination.termination_terms = []
         env_cfg.termination.timeout_at_border = False
         env_cfg.termination.timeout_at_finished = False
diff --git a/rsl_rl/rsl_rl/modules/__init__.py b/rsl_rl/rsl_rl/modules/__init__.py
index 3b21295..96cfc37 100644
--- a/rsl_rl/rsl_rl/modules/__init__.py
+++ b/rsl_rl/rsl_rl/modules/__init__.py
@@ -53,6 +53,8 @@ def build_actor_critic(env, policy_class_name, policy_cfg):
         policy_cfg["num_critic_obs"] = num_critic_obs
     if not "num_actions" in policy_cfg:
         policy_cfg["num_actions"] = env.num_actions
+    if not "lin_vel_x" in policy_cfg:
+        policy_cfg["lin_vel_x"] = env.command_ranges["lin_vel_x"]
 
     actor_critic: ActorCritic = actor_critic_class(**policy_cfg)
 
diff --git a/rsl_rl/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/rsl_rl/modules/actor_critic_recurrent.py
index d84506d..2757146 100644
--- a/rsl_rl/rsl_rl/modules/actor_critic_recurrent.py
+++ b/rsl_rl/rsl_rl/modules/actor_critic_recurrent.py
@@ -62,6 +62,16 @@ class ActorCriticRecurrent(ActorCritic):
 
         activation = get_activation(activation)
 
+        self.velocity_planner = nn.Sequential(
+            nn.Linear(num_actor_obs-3, 256),
+            nn.ELU(),
+            nn.Linear(256, 128),
+            nn.ELU(),
+            nn.Linear(128, 1),
+            nn.Tanh()
+        )
+
+        self.lin_vel_x = kwargs["lin_vel_x"]
         self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
         self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
 
@@ -73,6 +83,11 @@ class ActorCriticRecurrent(ActorCritic):
         self.memory_c.reset(dones)
 
     def act(self, observations, masks=None, hidden_states=None):
+        vel_obs = torch.cat([observations[:, :9], observations[:, 12:]], dim=1)
+        velocity = self.velocity_planner(vel_obs)
+        velocity = torch.clip(velocity, self.lin_vel_x[0], self.lin_vel_x[1])
+        self.velocity = velocity
+        observations[:, 9] = velocity
         input_a = self.memory_a(observations, masks, hidden_states)
         return super().act(input_a.squeeze(0))
 