diff --git a/legged_gym/legged_gym/envs/base/legged_robot.py b/legged_gym/legged_gym/envs/base/legged_robot.py index 8cffda3..cbbcbbb 100644 --- a/legged_gym/legged_gym/envs/base/legged_robot.py +++ b/legged_gym/legged_gym/envs/base/legged_robot.py @@ -700,6 +700,7 @@ class LeggedRobot(BaseTask): Returns: [torch.Tensor]: Torques sent to the simulation """ + actions = self.motor_strength * actions #pd controller if isinstance(self.cfg.control.action_scale, (tuple, list)): self.cfg.control.action_scale = torch.tensor(self.cfg.control.action_scale, device= self.sim_device) diff --git a/legged_gym/legged_gym/scripts/play.py b/legged_gym/legged_gym/scripts/play.py index ceed8d1..491c504 100644 --- a/legged_gym/legged_gym/scripts/play.py +++ b/legged_gym/legged_gym/scripts/play.py @@ -103,6 +103,7 @@ def play(args): "wave", ] env_cfg.terrain.BarrierTrack_kwargs["leap"]["fake_offset"] = 0.1 + env_cfg.terrain.BarrierTrack_kwargs["draw_virtual_terrain"] = True else: env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1) env_cfg.env.episode_length_s = 60 @@ -131,7 +132,7 @@ def play(args): env_cfg.viewer.draw_sensors = False if hasattr(env_cfg.terrain, "BarrierTrack_kwargs"): env_cfg.terrain.BarrierTrack_kwargs["draw_virtual_terrain"] = True - # train_cfg.runner.resume = (args.load_run is not None) + train_cfg.runner.resume = (args.load_run is not None) train_cfg.runner_class_name = "OnPolicyRunner" if args.no_throw: @@ -196,6 +197,8 @@ def play(args): ) agent_model = ppo_runner.alg.actor_critic policy = ppo_runner.get_inference_policy(device=env.device) + if args.sample: + policy = agent_model.act ### get obs_slice to read the obs # obs_slice = get_obs_slice(env.obs_segments, "engaging_block") @@ -205,7 +208,6 @@ def play(args): export_policy_as_jit(ppo_runner.alg.actor_critic, path) print('Exported policy as jit script to: ', path) if RECORD_FRAMES: - os.mkdir(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "images"), exist_ok= True) transform = gymapi.Transform() transform.p = gymapi.Vec3(*env_cfg.viewer.pos) transform.r = gymapi.Quat.from_euler_zyx(0., 0., -np.pi/2) @@ -214,6 +216,8 @@ def play(args): env.envs[0], transform= transform, ) + if not os.path.exists(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir)): + os.makedirs(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir)) logger = Logger(env.dt) robot_index = 0 # which robot is used for logging diff --git a/legged_gym/legged_gym/utils/terrain/barrier_track.py b/legged_gym/legged_gym/utils/terrain/barrier_track.py index 6f31b0b..660748c 100644 --- a/legged_gym/legged_gym/utils/terrain/barrier_track.py +++ b/legged_gym/legged_gym/utils/terrain/barrier_track.py @@ -1515,7 +1515,7 @@ class BarrierTrack: block_idx = torch.floor(forward_distance / self.env_block_length).to(int) # (n,) block_idx_clipped = torch.clip( block_idx, - 0., + 0, (self.n_blocks_per_track - 1), ) @@ -1680,7 +1680,7 @@ class BarrierTrack: dim= -1, ) if mask_only: - return distance_to_edge < self.track_kwargs["stairsup"].get("residual_distance", 0.05) + return (distance_to_edge < self.track_kwargs["stairsup"].get("residual_distance", 0.05)).to(torch.float32) else: return torch.clip( self.track_kwargs["stairsup"].get("residual_distance", 0.05) - distance_to_edge, diff --git a/rsl_rl/rsl_rl/storage/rollout_files/rollout_dataset.py b/rsl_rl/rsl_rl/storage/rollout_files/rollout_dataset.py index 1a16ddf..5a411c8 100644 --- a/rsl_rl/rsl_rl/storage/rollout_files/rollout_dataset.py +++ b/rsl_rl/rsl_rl/storage/rollout_files/rollout_dataset.py @@ -126,6 +126,8 @@ class RolloutDataset(RolloutFileBase): self.traj_datas = [None for _ in range(self.num_envs)] self.traj_cursors = np.zeros(self.num_envs, dtype= int) + self.refresh_handlers() + def _refresh_traj_data(self, env_idx): """ refresh `self.traj_data` based on current traj_file_idxs[env_idx]. usually called after refreshing traj_handler or updated traj_file_idxs[env_idx] @@ -225,7 +227,7 @@ class RolloutDataset(RolloutFileBase): ) dones = torch.empty( leading_dims, - dtype= self.traj_datas[0]["dones"].dtype, + dtype= bool, device= self.device, ) timeouts = torch.empty(