[fix] typo and bugs in demonstration loading for DAgger

* Fix accessing traj_data before the trajectory handlers are refreshed
* Fix motor_strength not being applied when LeggedRobotNoisyMixin is not used
* Fix a typo in barrier_track
Ziwen Zhuang 2024-09-07 00:16:03 +08:00
parent 1ffd6d7c05
commit 637080edaf
4 changed files with 12 additions and 5 deletions

View File

@@ -700,6 +700,7 @@ class LeggedRobot(BaseTask):
         Returns:
             [torch.Tensor]: Torques sent to the simulation
         """
+        actions = self.motor_strength * actions
         #pd controller
         if isinstance(self.cfg.control.action_scale, (tuple, list)):
             self.cfg.control.action_scale = torch.tensor(self.cfg.control.action_scale, device= self.sim_device)
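Context for this fix: without LeggedRobotNoisyMixin, the base `_compute_torques` fed raw actions straight into the PD controller, so any per-motor strength randomization stored in `self.motor_strength` was silently ignored. Below is a minimal sketch of the intended behavior; the standalone `compute_torques` function and the gain values are illustrative assumptions, and only the `motor_strength` scaling step mirrors the line added above.

```python
import torch

def compute_torques(actions, motor_strength, p_gains, d_gains,
                    default_dof_pos, dof_pos, dof_vel, action_scale=0.25):
    """Sketch of a PD torque computation with per-motor strength scaling."""
    # Apply the (possibly randomized) per-motor strength before the PD controller,
    # matching the line added in the hunk above.
    actions = motor_strength * actions
    # PD controller: track joint targets of default_dof_pos + scaled action.
    targets = action_scale * actions + default_dof_pos
    return p_gains * (targets - dof_pos) - d_gains * dof_vel

# Toy example: 2 envs, 12 joints, strength randomized in [0.9, 1.1].
num_envs, num_dof = 2, 12
torques = compute_torques(
    actions=torch.randn(num_envs, num_dof),
    motor_strength=torch.empty(num_envs, num_dof).uniform_(0.9, 1.1),
    p_gains=torch.full((num_dof,), 20.0),
    d_gains=torch.full((num_dof,), 0.5),
    default_dof_pos=torch.zeros(num_dof),
    dof_pos=torch.zeros(num_envs, num_dof),
    dof_vel=torch.zeros(num_envs, num_dof),
)
print(torques.shape)  # torch.Size([2, 12])
```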

View File

@@ -103,6 +103,7 @@ def play(args):
             "wave",
         ]
         env_cfg.terrain.BarrierTrack_kwargs["leap"]["fake_offset"] = 0.1
+        env_cfg.terrain.BarrierTrack_kwargs["draw_virtual_terrain"] = True
     else:
         env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
         env_cfg.env.episode_length_s = 60
@@ -131,7 +132,7 @@ def play(args):
     env_cfg.viewer.draw_sensors = False
     if hasattr(env_cfg.terrain, "BarrierTrack_kwargs"):
         env_cfg.terrain.BarrierTrack_kwargs["draw_virtual_terrain"] = True
-    # train_cfg.runner.resume = (args.load_run is not None)
+    train_cfg.runner.resume = (args.load_run is not None)
     train_cfg.runner_class_name = "OnPolicyRunner"
     if args.no_throw:
@@ -196,6 +197,8 @@ def play(args):
     )
     agent_model = ppo_runner.alg.actor_critic
     policy = ppo_runner.get_inference_policy(device=env.device)
+    if args.sample:
+        policy = agent_model.act
     ### get obs_slice to read the obs
     # obs_slice = get_obs_slice(env.obs_segments, "engaging_block")
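With `--sample`, the play script replaces the deterministic inference policy with the actor-critic's `act` call, which in rsl_rl-style actor-critics draws actions from the policy's Gaussian instead of returning its mean; sampled actions give more varied demonstrations when rolling out for DAgger. A toy sketch of that distinction (the `ToyGaussianActor` is illustrative, not the repository's model):

```python
import torch
import torch.nn as nn

class ToyGaussianActor(nn.Module):
    """Minimal stand-in for an actor-critic exposing act() and act_inference()."""
    def __init__(self, num_obs, num_actions):
        super().__init__()
        self.mean_net = nn.Linear(num_obs, num_actions)
        self.log_std = nn.Parameter(torch.zeros(num_actions))

    def act_inference(self, obs):
        # Deterministic path: return the mean action (what the inference policy uses).
        return self.mean_net(obs)

    def act(self, obs):
        # Stochastic path used with --sample: draw from the Gaussian around the mean.
        mean = self.mean_net(obs)
        return torch.distributions.Normal(mean, self.log_std.exp()).sample()

actor = ToyGaussianActor(num_obs=8, num_actions=3)
obs = torch.randn(2, 8)
print(actor.act_inference(obs))  # same output for repeated calls with fixed obs
print(actor.act(obs))            # varies from call to call
```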
@@ -205,7 +208,6 @@ def play(args):
         export_policy_as_jit(ppo_runner.alg.actor_critic, path)
         print('Exported policy as jit script to: ', path)
     if RECORD_FRAMES:
-        os.mkdir(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "images"), exist_ok= True)
         transform = gymapi.Transform()
         transform.p = gymapi.Vec3(*env_cfg.viewer.pos)
         transform.r = gymapi.Quat.from_euler_zyx(0., 0., -np.pi/2)
@@ -214,6 +216,8 @@ def play(args):
             env.envs[0],
             transform= transform,
         )
+        if not os.path.exists(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir)):
+            os.makedirs(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir))
     logger = Logger(env.dt)
     robot_index = 0 # which robot is used for logging
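The removed `os.mkdir(..., exist_ok= True)` call could never work: `os.mkdir` has no `exist_ok` keyword (only `os.makedirs` does) and it cannot create missing parent directories. The replacement checks for the directory and builds it with `os.makedirs` under the `--frames_dir` argument. A small equivalent sketch, with a hypothetical root directory standing in for `LEGGED_GYM_ROOT_DIR`:

```python
import os

def ensure_frames_dir(root_dir, frames_dir="images"):
    """Create logs/<frames_dir> under root_dir, tolerating an existing directory."""
    path = os.path.join(root_dir, "logs", frames_dir)
    # os.mkdir(path, exist_ok=True) raises TypeError: exist_ok is a keyword of
    # os.makedirs only, and os.mkdir would also require "logs" to exist already.
    os.makedirs(path, exist_ok=True)
    return path

print(ensure_frames_dir("/tmp/legged_gym_demo"))  # /tmp/legged_gym_demo/logs/images
```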

View File

@@ -1515,7 +1515,7 @@ class BarrierTrack:
         block_idx = torch.floor(forward_distance / self.env_block_length).to(int) # (n,)
         block_idx_clipped = torch.clip(
             block_idx,
-            0.,
+            0,
             (self.n_blocks_per_track - 1),
         )
@@ -1680,7 +1680,7 @@ class BarrierTrack:
             dim= -1,
         )
         if mask_only:
-            return distance_to_edge < self.track_kwargs["stairsup"].get("residual_distance", 0.05)
+            return (distance_to_edge < self.track_kwargs["stairsup"].get("residual_distance", 0.05)).to(torch.float32)
         else:
             return torch.clip(
                 self.track_kwargs["stairsup"].get("residual_distance", 0.05) - distance_to_edge,
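Both BarrierTrack changes are dtype fixes. `block_idx` is an integer index tensor, so clipping it against the float literal `0.` mixes dtypes (on recent PyTorch versions `torch.clip` with a float bound on an integer tensor can error or promote); using `0` keeps the result an integer index. The second hunk casts the boolean mask to `float32`, presumably so the `mask_only` branch returns the same floating dtype as the residual-distance branch. A small sketch with made-up track numbers:

```python
import torch

n_blocks_per_track = 4
env_block_length = 2.0
forward_distance = torch.tensor([1.0, 5.5, 9.7])

# Integer block index per env; must stay an integer dtype to be usable as an index.
block_idx = torch.floor(forward_distance / env_block_length).to(int)     # tensor([0, 2, 4])
block_idx_clipped = torch.clip(block_idx, 0, n_blocks_per_track - 1)     # tensor([0, 2, 3])

block_heights = torch.tensor([0.0, 0.1, 0.2, 0.3])
print(block_heights[block_idx_clipped])          # tensor([0.0000, 0.2000, 0.3000])

# Mask branch: boolean comparison cast to float32, matching the other branch's dtype.
distance_to_edge = torch.tensor([0.02, 0.10, 0.04])
print((distance_to_edge < 0.05).to(torch.float32))  # tensor([1., 0., 1.])
```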

View File

@@ -126,6 +126,8 @@ class RolloutDataset(RolloutFileBase):
         self.traj_datas = [None for _ in range(self.num_envs)]
         self.traj_cursors = np.zeros(self.num_envs, dtype= int)
+        self.refresh_handlers()
     def _refresh_traj_data(self, env_idx):
         """ refresh `self.traj_data` based on current traj_file_idxs[env_idx]. usually called
         after refreshing traj_handler or updated traj_file_idxs[env_idx]
@@ -225,7 +227,7 @@ class RolloutDataset(RolloutFileBase):
         )
         dones = torch.empty(
             leading_dims,
-            dtype= self.traj_datas[0]["dones"].dtype,
+            dtype= bool,
             device= self.device,
         )
         timeouts = torch.empty(
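The RolloutDataset hunks fix an ordering bug in the DAgger demonstration loader: `self.traj_datas` starts as a list of `None` placeholders, so any code that reached for `self.traj_datas[0]["dones"].dtype` before `refresh_handlers()` populated it crashed with a `TypeError`. The fix calls `refresh_handlers()` right after the buffers are initialized and gives the `dones` buffer an explicit `bool` dtype instead of peeking at a trajectory that may not be loaded. A self-contained sketch of the pattern (the class and its loader are illustrative, not the repository's actual RolloutDataset):

```python
import numpy as np
import torch

class RolloutBuffers:
    """Illustrates the initialization-order fix: handlers are refreshed before any
    code touches traj_datas, and buffer dtypes are explicit rather than derived
    from a trajectory that might still be None."""

    def __init__(self, num_envs, device="cpu"):
        self.num_envs = num_envs
        self.device = device
        self.traj_datas = [None for _ in range(num_envs)]   # filled per env by refresh_handlers
        self.traj_cursors = np.zeros(num_envs, dtype=int)
        self.refresh_handlers()                              # populate before anything reads traj_datas

    def refresh_handlers(self):
        # Stand-in for loading trajectory files from disk.
        for env_idx in range(self.num_envs):
            self.traj_datas[env_idx] = {"dones": torch.zeros(8, dtype=torch.bool)}

    def empty_dones(self, leading_dims):
        # Explicit dtype: no dependence on traj_datas[0] already being loaded.
        return torch.empty(leading_dims, dtype=bool, device=self.device)

buffers = RolloutBuffers(num_envs=4)
print(buffers.empty_dones((4, 24)).dtype)  # torch.bool
```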