321 lines
15 KiB
Python
321 lines
15 KiB
Python
import os
|
|
import os.path as osp
|
|
import json
|
|
import pickle
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from rsl_rl.utils.utils import get_obs_slice
|
|
import rsl_rl.utils.data_compresser as compresser
|
|
from rsl_rl.storage.rollout_storage import RolloutStorage
|
|
|
|
class DemonstrationSaver:
    """ Roll out a policy in an environment and save the resulting trajectories
    to disk as demonstration data.

    On-disk layout:
        save_dir/
            metadata.json                              # dataset statistics and config
            trajectory_{i}/
                traj_{start:06d}_{stop:06d}.pickle     # consecutive chunks of one trajectory
    """

    def __init__(self,
            env,
            policy, # any object with "act(obs, critic_obs)" method to get actions and "get_hidden_states()" method to get hidden states
            save_dir,
            rollout_storage_length= 64,
            min_timesteps= 1e6,
            min_episodes= 10000,
            success_traj_only = False, # if True, trajectories terminated by failure (done but not timeout) are discarded
            use_critic_obs= False,
            obs_disassemble_mapping= None,
            demo_by_sample= False,
        ):
        """
        Args:
            obs_disassemble_mapping (dict): If set, the obs segment will be compressed using given
                type. example: {"forward_depth": "normalized_image", "forward_rgb": "normalized_image"}
            demo_by_sample (bool): if True, the action will be sampled (policy.act) from the
                policy instead of using the mean (policy.act_inference).
        """
        self.env = env
        self.policy = policy

        self.save_dir = save_dir
        self.rollout_storage_length = rollout_storage_length
        self.min_timesteps = min_timesteps
        self.min_episodes = min_episodes
        self.use_critic_obs = use_critic_obs
        self.success_traj_only = success_traj_only
        self.obs_disassemble_mapping = obs_disassemble_mapping
        self.demo_by_sample = demo_by_sample
        self.RolloutStorageCls = RolloutStorage

    def init_traj_handlers(self):
        """ Assign one trajectory directory per env and restore the dataset
        statistics when continuing from previously collected data.
        """
        # more than just metadata.json present -> previous data exists, continue from it
        if len(os.listdir(self.save_dir)) > 1:
            print("Continuing from previous data. You have to make sure the environment configuration is the same.")
            prev_traj = [x for x in os.listdir(self.save_dir) if x.startswith("trajectory_")]
            prev_traj.sort(key= lambda x: int(x.split("_")[1]))
            # reuse empty (never filled) trajectory directories first
            self.traj_idxs = []
            for f in prev_traj:
                if len(os.listdir(osp.join(self.save_dir, f))) == 0:
                    self.traj_idxs.append(int(f.split("_")[1]))
            # not enough empty directories: allocate fresh indices after the largest one
            if len(self.traj_idxs) < self.env.num_envs:
                max_traj_idx = max(self.traj_idxs) if len(self.traj_idxs) > 0 else int(prev_traj[-1].split("_")[1])
                for _ in range(self.env.num_envs - len(self.traj_idxs)):
                    max_traj_idx += 1
                    self.traj_idxs.append(max_traj_idx)
            self.traj_idxs = np.array(self.traj_idxs[:self.env.num_envs])
            # load the dataset statistics
            with open(osp.join(self.save_dir, "metadata.json"), "r") as f:
                metadata = json.load(f)
            self.total_traj_completed = metadata["total_trajectories"]
            self.total_timesteps = metadata["total_timesteps"]
        else:
            self.traj_idxs = np.arange(self.env.num_envs)
            self.total_traj_completed = 0
            self.total_timesteps = 0
        self.metadata["total_timesteps"] = self.total_timesteps
        self.metadata["total_trajectories"] = self.total_traj_completed
        for traj_idx in self.traj_idxs:
            os.makedirs(osp.join(self.save_dir, f"trajectory_{traj_idx}"), exist_ok= True)
        # number of steps already dumped for the currently open trajectory of each env
        self.dumped_traj_lengths = np.zeros(self.env.num_envs, dtype= np.int32)

        # initialize compressing parameters if needed
        if self.obs_disassemble_mapping is not None:
            self.metadata["obs_segments"] = self.env.obs_segments
            self.metadata["obs_disassemble_mapping"] = self.obs_disassemble_mapping

    def init_storage_buffer(self):
        """ Allocate the rollout storage that buffers transitions between file dumps. """
        self.rollout_storage = self.RolloutStorageCls(
            self.env.num_envs,
            self.rollout_storage_length,
            [self.env.num_obs],
            [self.env.num_privileged_obs],
            [self.env.num_actions],
            self.env.device,
        )
        self.transition = self.RolloutStorageCls.Transition()
        # timeouts are tracked separately because RolloutStorage has no slot for them
        self.transition_has_timeouts = False
        self.transition_timeouts = torch.zeros(self.rollout_storage_length, self.env.num_envs, dtype= torch.bool, device= self.env.device)

    def check_stop(self):
        """ Return True only when both the episode and the timestep quotas are reached. """
        return (self.total_traj_completed >= self.min_episodes) \
            and (self.total_timesteps >= self.min_timesteps)

    @torch.no_grad()
    def collect_step(self, step_i):
        """ Collect one step of demonstration data """
        actions, rewards, dones, infos, n_obs, n_critic_obs = self.get_transition()

        self.build_transition(step_i, actions, rewards, dones, infos)
        self.add_transition(step_i, infos)
        self.transition.clear()

        self.policy_reset(dones)
        self.obs, self.critic_obs = n_obs, n_critic_obs

    def get_policy_actions(self):
        """ Query the policy for actions, honoring use_critic_obs / demo_by_sample. """
        if self.use_critic_obs and self.demo_by_sample:
            actions = self.policy.act(self.critic_obs)
        elif self.use_critic_obs:
            actions = self.policy.act_inference(self.critic_obs)
        elif self.demo_by_sample:
            actions = self.policy.act(self.obs)
        else:
            actions = self.policy.act_inference(self.obs)
        return actions

    def get_transition(self):
        """ Step the environment once with the policy's action. """
        actions = self.get_policy_actions()
        n_obs, n_critic_obs, rewards, dones, infos = self.env.step(actions)
        return actions, rewards, dones, infos, n_obs, n_critic_obs

    def build_transition(self, step_i, actions, rewards, dones, infos):
        """ Fill the transition to meet the interface of rollout storage """
        self.transition.observations = self.obs
        if self.critic_obs is not None:
            self.transition.critic_observations = self.critic_obs
        self.transition.actions = actions
        self.transition.rewards = rewards
        self.transition.dones = dones

        # fill up some of the attributes to meet the interface of rollout storage, but not collected to files
        self.transition.values = torch.zeros_like(rewards).unsqueeze(-1)
        self.transition.actions_log_prob = torch.zeros_like(rewards)
        self.transition.action_mean = torch.zeros_like(actions)
        self.transition.action_sigma = torch.zeros_like(actions)

    def add_transition(self, step_i, infos):
        """ Push the transition into storage and record timeouts if the env reports them. """
        self.rollout_storage.add_transitions(self.transition)
        if "time_outs" in infos:
            self.transition_has_timeouts = True
            self.transition_timeouts[step_i] = infos["time_outs"]

    def policy_reset(self, dones):
        """ Reset the policy's internal state (e.g. RNN hidden state) for done envs. """
        if dones.any():
            self.policy.reset(dones)

    def dump_to_file(self, env_i, step_slice):
        """ dump the part of trajectory to the trajectory directory """
        traj_idx = self.traj_idxs[env_i]
        traj_dir = osp.join(self.save_dir, f"trajectory_{traj_idx}")
        traj_file = osp.join(
            traj_dir,
            f"traj_{self.dumped_traj_lengths[env_i]:06d}_{self.dumped_traj_lengths[env_i]+step_slice.stop-step_slice.start:06d}.pickle",
        )
        trajectory = self.wrap_up_trajectory(env_i, step_slice)
        with open(traj_file, 'wb') as f:
            pickle.dump(trajectory, f)
        self.dumped_traj_lengths[env_i] += step_slice.stop - step_slice.start
        self.total_timesteps += step_slice.stop - step_slice.start

    def dump_metadata(self):
        """ Refresh metadata.json with the current dataset statistics. """
        # cast through int(): numpy integer types (int32/int64) are not JSON serializable
        self.metadata["total_timesteps"] = int(self.total_timesteps)
        self.metadata["total_trajectories"] = int(self.total_traj_completed)
        with open(osp.join(self.save_dir, 'metadata.json'), 'w') as f:
            json.dump(self.metadata, f, indent= 4)

    def wrap_up_trajectory(self, env_i, step_slice):
        """ Extract one env's slice of the rollout storage as numpy arrays.

        `step_slice` must include the `done` step if one exists, so that a
        single file never spans two trajectories.
        """
        trajectory = dict(
            privileged_observations= self.rollout_storage.privileged_observations[step_slice, env_i].cpu().numpy(),
            actions= self.rollout_storage.actions[step_slice, env_i].cpu().numpy(),
            rewards= self.rollout_storage.rewards[step_slice, env_i].cpu().numpy(),
            dones= self.rollout_storage.dones[step_slice, env_i].cpu().numpy(),
            values= self.rollout_storage.values[step_slice, env_i].cpu().numpy(),
        )
        # compress observations components if set
        if self.obs_disassemble_mapping is not None:
            observations = self.rollout_storage.observations[step_slice, env_i].cpu().numpy() # (n_steps, d_obs)
            for component_name in self.metadata["obs_segments"].keys():
                obs_slice = get_obs_slice(self.metadata["obs_segments"], component_name)
                obs_component = observations[..., obs_slice[0]]
                if component_name in self.obs_disassemble_mapping:
                    # compress the component with the configured codec
                    obs_component = getattr(
                        compresser,
                        "compress_" + self.obs_disassemble_mapping[component_name],
                    )(obs_component)
                trajectory["obs_" + component_name] = obs_component
        else:
            # BUGFIX: a trailing comma here used to wrap the array in a 1-tuple
            trajectory["observations"] = self.rollout_storage.observations[step_slice, env_i].cpu().numpy()
        if self.transition_has_timeouts:
            trajectory["timeouts"] = self.transition_timeouts[step_slice, env_i].cpu().numpy()
        return trajectory

    def update_traj_handler(self, env_i, step_slice):
        """ update the trajectory file handler for the env_i """
        traj_idx = self.traj_idxs[env_i]

        if self.success_traj_only:
            # done by termination, not by timeout -> the trajectory failed
            if self.rollout_storage.dones[step_slice.stop-1, env_i] and (not self.transition_timeouts[step_slice.stop-1, env_i]):
                # discard the failed trajectory: remove its files and roll back the timestep count
                traj_dir = osp.join(self.save_dir, f"trajectory_{traj_idx}")
                for f in os.listdir(traj_dir):
                    if f.startswith("traj_"):
                        try:
                            # filename format: traj_{start:06d}_{stop:06d}.pickle
                            # BUGFIX: strip the ".pickle" extension before int(); the old
                            # parse always raised, so the count was never rolled back
                            start_timestep, stop_timestep = osp.splitext(f)[0].split("_")[1:]
                            self.total_timesteps -= int(stop_timestep) - int(start_timestep)
                        except ValueError:
                            pass # unexpected filename; still remove it, but keep the count
                    os.remove(osp.join(traj_dir, f))
                self.dumped_traj_lengths[env_i] = 0
                return

        # update the handlers to a new trajectory
        # Also, skip the trajectory directory that has data collected before this run.
        while len(os.listdir(osp.join(self.save_dir, f"trajectory_{traj_idx}"))) > 0:
            traj_idx = max(self.traj_idxs) + 1
            os.makedirs(osp.join(self.save_dir, f"trajectory_{traj_idx}"), exist_ok= True)
        self.traj_idxs[env_i] = traj_idx
        self.total_traj_completed += 1
        self.dumped_traj_lengths[env_i] = 0

    def save_steps(self):
        """ dump a series of transitions to the file """
        for rollout_env_i in range(self.rollout_storage.num_envs):
            # assumes the dones buffer is (n_steps, n_envs, 1) -- TODO confirm against RolloutStorage
            done_idxs = torch.where(self.rollout_storage.dones[:, rollout_env_i, 0])[0]
            if len(done_idxs) == 0:
                # no episode boundary: dump the whole rollout for this env
                self.dump_to_file(rollout_env_i, slice(0, self.rollout_storage.num_transitions_per_env))
            else:
                start_idx = 0
                for di in range(done_idxs.shape[0]):
                    end_idx = done_idxs[di].item()

                    # dump up to and including the done step, then move this env to a new trajectory
                    self.dump_to_file(rollout_env_i, slice(start_idx, end_idx+1))
                    self.update_traj_handler(rollout_env_i, slice(start_idx, end_idx+1))

                    start_idx = end_idx + 1
                # the tail after the last done belongs to the next (still open) trajectory
                if start_idx < self.rollout_storage.num_transitions_per_env:
                    self.dump_to_file(rollout_env_i, slice(start_idx, self.rollout_storage.num_transitions_per_env))
        self.dump_metadata()

    def collect_and_save(self, config= None):
        """ Run the rollout to collect the demonstration data and save it to the file """
        # create directory and save metadata file
        self.metadata = {
            'config': config,
            'env': self.env.__class__.__name__,
            'policy': self.policy.__class__.__name__,
            'rollout_storage_length': self.rollout_storage_length,
            'success_traj_only': self.success_traj_only,
            'min_timesteps': self.min_timesteps,
            'min_episodes': self.min_episodes,
            'use_critic_obs': self.use_critic_obs,
        }
        # create env-wise trajectory file handler
        os.makedirs(self.save_dir, exist_ok= True)
        self.init_traj_handlers()
        self.init_storage_buffer()

        with open(osp.join(self.save_dir, 'metadata.json'), 'w') as f:
            # It will be refreshed once the collection is done.
            json.dump(self.metadata, f, indent= 4)

        # collect the demonstration data
        self.env.reset()
        obs = self.env.get_observations()
        privileged_obs = self.env.get_privileged_observations()
        critic_obs = privileged_obs if privileged_obs is not None else obs
        self.obs, self.critic_obs = obs, critic_obs
        while not self.check_stop():
            for step_i in range(self.rollout_storage_length):
                self.collect_step(step_i)
            self.save_steps()
            self.rollout_storage.clear()
            self.print_log()

        # close the trajectory file handlers
        self.close()

    def print_log(self):
        """ print the log """
        self.print_log_time = time.monotonic()
        if hasattr(self, "last_print_log_time"):
            elapsed = self.print_log_time - self.last_print_log_time
            print("time elapsed:", elapsed)
            # NOTE(review): total_timesteps is cumulative over the whole dataset,
            # so this prints an average rate, not the rate since the last log.
            print("throughput:", self.total_timesteps / elapsed)
        self.last_print_log_time = self.print_log_time
        print("total_timesteps:", self.total_timesteps)
        print("total_trajectories", self.total_traj_completed)

    def close(self):
        """ check empty directories and remove them """
        pass

    def __del__(self):
        """ In case the process stops accidentally, clean up empty trajectory
        directories and flush the metadata so the dataset stays consistent.
        """
        # guard against partially constructed instances (__init__ failed early)
        if not hasattr(self, "traj_idxs") or not hasattr(self, "dumped_traj_lengths"):
            return
        for traj_idx in self.traj_idxs:
            traj_dir = osp.join(self.save_dir, f"trajectory_{traj_idx}")
            # remove the empty directories
            if len(os.listdir(traj_dir)) == 0:
                os.rmdir(traj_dir)
        for timestep_count in self.dumped_traj_lengths:
            # NOTE(review): dump_to_file already counted these steps in
            # total_timesteps; adding them again looks like double counting -- confirm.
            self.total_timesteps += timestep_count
        self.dump_metadata()
        print(f"Saved dataset in {self.save_dir}")