[fix] typo and bugs in loading demonstration for DAgger
* Fix the problem of accessing traj_data before refreshing trajectory handlers
* Fix motor_strength when not using LeggedRobotNoisyMixin
* Fix barrier_track typo
This commit is contained in:
parent
1ffd6d7c05
commit
637080edaf
|
@ -700,6 +700,7 @@ class LeggedRobot(BaseTask):
|
|||
Returns:
|
||||
[torch.Tensor]: Torques sent to the simulation
|
||||
"""
|
||||
actions = self.motor_strength * actions
|
||||
#pd controller
|
||||
if isinstance(self.cfg.control.action_scale, (tuple, list)):
|
||||
self.cfg.control.action_scale = torch.tensor(self.cfg.control.action_scale, device= self.sim_device)
|
||||
|
|
|
@ -103,6 +103,7 @@ def play(args):
|
|||
"wave",
|
||||
]
|
||||
env_cfg.terrain.BarrierTrack_kwargs["leap"]["fake_offset"] = 0.1
|
||||
env_cfg.terrain.BarrierTrack_kwargs["draw_virtual_terrain"] = True
|
||||
else:
|
||||
env_cfg.env.num_envs = min(env_cfg.env.num_envs, 1)
|
||||
env_cfg.env.episode_length_s = 60
|
||||
|
@ -131,7 +132,7 @@ def play(args):
|
|||
env_cfg.viewer.draw_sensors = False
|
||||
if hasattr(env_cfg.terrain, "BarrierTrack_kwargs"):
|
||||
env_cfg.terrain.BarrierTrack_kwargs["draw_virtual_terrain"] = True
|
||||
# train_cfg.runner.resume = (args.load_run is not None)
|
||||
train_cfg.runner.resume = (args.load_run is not None)
|
||||
train_cfg.runner_class_name = "OnPolicyRunner"
|
||||
|
||||
if args.no_throw:
|
||||
|
@ -196,6 +197,8 @@ def play(args):
|
|||
)
|
||||
agent_model = ppo_runner.alg.actor_critic
|
||||
policy = ppo_runner.get_inference_policy(device=env.device)
|
||||
if args.sample:
|
||||
policy = agent_model.act
|
||||
### get obs_slice to read the obs
|
||||
# obs_slice = get_obs_slice(env.obs_segments, "engaging_block")
|
||||
|
||||
|
@ -205,7 +208,6 @@ def play(args):
|
|||
export_policy_as_jit(ppo_runner.alg.actor_critic, path)
|
||||
print('Exported policy as jit script to: ', path)
|
||||
if RECORD_FRAMES:
|
||||
os.mkdir(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", "images"), exist_ok= True)
|
||||
transform = gymapi.Transform()
|
||||
transform.p = gymapi.Vec3(*env_cfg.viewer.pos)
|
||||
transform.r = gymapi.Quat.from_euler_zyx(0., 0., -np.pi/2)
|
||||
|
@ -214,6 +216,8 @@ def play(args):
|
|||
env.envs[0],
|
||||
transform= transform,
|
||||
)
|
||||
if not os.path.exists(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir)):
|
||||
os.makedirs(os.path.join(LEGGED_GYM_ROOT_DIR, "logs", args.frames_dir))
|
||||
|
||||
logger = Logger(env.dt)
|
||||
robot_index = 0 # which robot is used for logging
|
||||
|
|
|
@ -1515,7 +1515,7 @@ class BarrierTrack:
|
|||
block_idx = torch.floor(forward_distance / self.env_block_length).to(int) # (n,)
|
||||
block_idx_clipped = torch.clip(
|
||||
block_idx,
|
||||
0.,
|
||||
0,
|
||||
(self.n_blocks_per_track - 1),
|
||||
)
|
||||
|
||||
|
@ -1680,7 +1680,7 @@ class BarrierTrack:
|
|||
dim= -1,
|
||||
)
|
||||
if mask_only:
|
||||
return distance_to_edge < self.track_kwargs["stairsup"].get("residual_distance", 0.05)
|
||||
return (distance_to_edge < self.track_kwargs["stairsup"].get("residual_distance", 0.05)).to(torch.float32)
|
||||
else:
|
||||
return torch.clip(
|
||||
self.track_kwargs["stairsup"].get("residual_distance", 0.05) - distance_to_edge,
|
||||
|
|
|
@ -126,6 +126,8 @@ class RolloutDataset(RolloutFileBase):
|
|||
self.traj_datas = [None for _ in range(self.num_envs)]
|
||||
self.traj_cursors = np.zeros(self.num_envs, dtype= int)
|
||||
|
||||
self.refresh_handlers()
|
||||
|
||||
def _refresh_traj_data(self, env_idx):
|
||||
""" refresh `self.traj_data` based on current traj_file_idxs[env_idx]. usually called
|
||||
after refreshing traj_handler or updated traj_file_idxs[env_idx]
|
||||
|
@ -225,7 +227,7 @@ class RolloutDataset(RolloutFileBase):
|
|||
)
|
||||
dones = torch.empty(
|
||||
leading_dims,
|
||||
dtype= self.traj_datas[0]["dones"].dtype,
|
||||
dtype= bool,
|
||||
device= self.device,
|
||||
)
|
||||
timeouts = torch.empty(
|
||||
|
|
Loading…
Reference in New Issue