# parkour/rsl_rl/rsl_rl/storage/rollout_dataset.py

import os
import os.path as osp
import pickle
from collections import namedtuple, OrderedDict
import json
import random
import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info
import rsl_rl.utils.data_compresser as compresser


class RolloutDataset(IterableDataset):
Transitions = namedtuple("Transitions", [
"observation", "privileged_observation", "action", "reward", "done",
])
def __init__(self,
data_dir= None,
scan_dir= None,
num_envs= 1,
dataset_loops: int= 1,
subset_traj= None, # (start_idx, end_idx) as a slice
            random_shuffle_traj_order= False, # If True, trajectories will be loaded in a random order
            keep_latest_ratio= 1.0, # If < 1., only keeps this ratio of the latest trajectories
            keep_latest_n_trajs= 0, # If > 0, ignores keep_latest_ratio and keeps only the latest n trajectories
            starting_frame_range= [0, 1], # the starting timestep is chosen uniformly from this range when each new trajectory is loaded;
                # if the sampled starting frame exceeds the trajectory length, the starting frame will be 0
            load_data_to_device= True, # If True, the traj_data will be loaded directly to rl_device rather than kept as numpy arrays
rl_device= "cpu",
):
""" choose data_dir or scan_dir, but not both. If scan_dir is chosen, the dataset will scan the
directory and treat each direct subdirectory as a dataset everytime it is initialized.
"""
self.data_dir = data_dir
self.scan_dir = scan_dir
self.num_envs = num_envs
self.max_loops = dataset_loops
self.subset_traj = subset_traj
self.random_shuffle_traj_order = random_shuffle_traj_order
self.keep_latest_ratio = keep_latest_ratio
self.keep_latest_n_trajs = keep_latest_n_trajs
self.starting_frame_range = starting_frame_range
self.load_data_to_device = load_data_to_device
self.rl_device = rl_device
# check arguments
        assert not (self.data_dir is None and self.scan_dir is None), "data_dir and scan_dir cannot both be None"
self.num_looped = 0
def initialize(self):
self.load_dataset_directory()
if self.subset_traj is not None:
self.unused_traj_dirs = self.unused_traj_dirs[self.subset_traj[0]: self.subset_traj[1]]
if self.keep_latest_ratio < 1. or self.keep_latest_n_trajs > 0:
self.unused_traj_dirs = sorted(
self.unused_traj_dirs,
key= lambda x: os.stat(x).st_ctime,
)
if self.keep_latest_n_trajs > 0:
self.unused_traj_dirs = self.unused_traj_dirs[-self.keep_latest_n_trajs:]
            else:
                # keep only the newest `keep_latest_ratio` fraction; the list is sorted oldest-first
                self.unused_traj_dirs = self.unused_traj_dirs[int(len(self.unused_traj_dirs) * (1. - self.keep_latest_ratio)):]
print("Using a subset of trajectories, total number of trajectories: ", len(self.unused_traj_dirs))
if self.random_shuffle_traj_order:
random.shuffle(self.unused_traj_dirs)
# attributes that handles trajectory files for each env
self.current_traj_dirs = [None for _ in range(self.num_envs)]
self.trajectory_files = [[] for _ in range(self.num_envs)]
self.traj_file_idxs = np.zeros(self.num_envs, dtype= np.int32)
self.traj_step_idxs = np.zeros(self.num_envs, dtype= np.int32)
self.traj_datas = [None for _ in range(self.num_envs)]
env_idx = 0
while env_idx < self.num_envs:
if len(self.unused_traj_dirs) == 0:
print("Not enough trajectories, waiting to re-initialize. Press Enter to continue....")
input()
self.initialize()
return
starting_frame = torch.randint(self.starting_frame_range[0], self.starting_frame_range[1], (1,)).item()
update_result = self.update_traj_handle(env_idx, self.unused_traj_dirs.pop(0), starting_frame)
if update_result:
env_idx += 1
self.dataset_drained = False
def update_traj_handle(self, env_idx, traj_dir, starting_step_idx= 0):
""" Load and update the trajectory handle for a given env_idx.
Also update traj_step_idxs.
Return whether the trajectory is successfully loaded
"""
self.current_traj_dirs[env_idx] = traj_dir
try:
self.trajectory_files[env_idx] = sorted(
os.listdir(self.current_traj_dirs[env_idx]),
key= lambda x: int(x.split("_")[1]),
)
self.traj_file_idxs[env_idx] = 0
        except (OSError, ValueError, IndexError):
            # the trajectory directory is missing or its file names are not in the expected format
self.nullify_traj_handles(env_idx)
return False
self.traj_datas[env_idx] = self.load_traj_data(
env_idx,
self.traj_file_idxs[env_idx],
new_episode= True,
)
if self.traj_datas[env_idx] is None:
self.nullify_traj_handles(env_idx)
return False
        # the numbers in the file name encode the timestep slice covered by the file
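        # e.g. a hypothetical file "traj_100_200.pickle" covers steps [100, 200),
        # so its max timestep is 199 and its min step is 100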
current_file_max_timestep = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[2]) - 1
while current_file_max_timestep < starting_step_idx:
self.traj_file_idxs[env_idx] += 1
if self.traj_file_idxs[env_idx] >= len(self.trajectory_files[env_idx]):
# trajectory length is shorter than starting_step_idx, set starting_step_idx to 0
starting_step_idx = 0
self.traj_file_idxs[env_idx] = 0
break
current_file_max_timestep = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[2]) - 1
current_file_min_step = int(self.trajectory_files[env_idx][self.traj_file_idxs[env_idx]].split(".")[0].split("_")[1])
self.traj_step_idxs[env_idx] = starting_step_idx - current_file_min_step
if self.traj_file_idxs[env_idx] > 0:
# reload the traj_data because traj_file_idxs is updated
self.traj_datas[env_idx] = self.load_traj_data(
env_idx,
self.traj_file_idxs[env_idx],
new_episode= True,
)
if self.traj_datas[env_idx] is None:
self.nullify_traj_handles(env_idx)
return False
return True
def nullify_traj_handles(self, env_idx):
self.current_traj_dirs[env_idx] = ""
self.trajectory_files[env_idx] = []
self.traj_file_idxs[env_idx] = 0
self.traj_step_idxs[env_idx] = 0
self.traj_datas[env_idx] = None
def load_dataset_directory(self):
if self.scan_dir is not None:
if not osp.isdir(self.scan_dir):
print("RolloutDataset: scan_dir {} does not exist, creating...".format(self.scan_dir))
os.makedirs(self.scan_dir)
self.data_dir = sorted([
osp.join(self.scan_dir, x) \
for x in os.listdir(self.scan_dir) \
if osp.isdir(osp.join(self.scan_dir, x)) and osp.isfile(osp.join(self.scan_dir, x, "metadata.json"))
])
if isinstance(self.data_dir, list):
total_timesteps = 0
self.unused_traj_dirs = []
for data_dir in self.data_dir:
try:
new_trajectories = sorted([
osp.join(data_dir, x) \
for x in os.listdir(data_dir) \
if x.startswith("trajectory_") and len(os.listdir(osp.join(data_dir, x))) > 0
], key= lambda x: int(x.split("_")[-1]))
            except (OSError, ValueError):
                # data_dir vanished or contains unexpected names; skip it
                continue
self.unused_traj_dirs.extend(new_trajectories)
try:
with open(osp.join(data_dir, "metadata.json"), "r") as f:
self.metadata = json.load(f, object_pairs_hook= OrderedDict)
total_timesteps += self.metadata["total_timesteps"]
            except (OSError, json.JSONDecodeError, KeyError):
                pass # metadata.json missing or malformed; skip
print("RolloutDataset: Loaded data from multiple directories. The metadata is from the last directory.")
print("RolloutDataset: Total number of timesteps: ", total_timesteps)
print("RolloutDataset: Total number of trajectories: ", len(self.unused_traj_dirs))
else:
self.unused_traj_dirs = sorted([
osp.join(self.data_dir, x) \
for x in os.listdir(self.data_dir) \
if x.startswith("trajectory_") and len(os.listdir(osp.join(self.data_dir, x))) > 0
], key= lambda x: int(x.split("_")[-1]))
with open(osp.join(self.data_dir, "metadata.json"), "r") as f:
self.metadata = json.load(f, object_pairs_hook= OrderedDict)
# check if this dataset is initialized in worker process
worker_info = get_worker_info()
if worker_info is not None:
            self.max_loops = 1 # Let the sampler handle the loops
worker_id = worker_info.id
num_workers = worker_info.num_workers
            all_traj_dirs = self.unused_traj_dirs
            trajs_per_worker = len(all_traj_dirs) // num_workers
            self.unused_traj_dirs = all_traj_dirs[worker_id * trajs_per_worker: (worker_id + 1) * trajs_per_worker]
            if worker_id == num_workers - 1:
                # the last worker also takes the remainder trajectories
                self.unused_traj_dirs.extend(all_traj_dirs[num_workers * trajs_per_worker:])
print("RolloutDataset: Worker {} of {} initialized with {} trajectories".format(
worker_id, num_workers, len(self.unused_traj_dirs)
))
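    # Expected on-disk layout, inferred from the scanning logic above (names other
    # than "metadata.json" and the "trajectory_" prefix are illustrative):
    #   scan_dir/
    #       <run_name>/
    #           metadata.json
    #           trajectory_0/
    #               <prefix>_0_100.pickle      # steps [0, 100)
    #               <prefix>_100_200.pickle    # steps [100, 200)
    #           trajectory_1/
    #               ...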
    def assemble_obs_components(self, traj_data):
assert "obs_segments" in self.metadata, "Corrupted metadata, obs_segments not found in metadata"
observations = []
for component_name in self.metadata["obs_segments"].keys():
obs_component = traj_data.pop("obs_" + component_name)
if component_name in self.metadata["obs_disassemble_mapping"]:
obs_component = getattr(
compresser,
"decompress_" + self.metadata["obs_disassemble_mapping"][component_name],
)(obs_component)
observations.append(obs_component)
traj_data["observations"] = np.concatenate(observations, axis= -1) # (n_steps, d_obs)
return traj_data
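    # The metadata consumed above is expected to map each observation component to
    # its segment shape, and each disassembled component to the compression tag used
    # by rsl_rl.utils.data_compresser, e.g. (a hypothetical example):
    #   {"obs_segments": {"proprioception": [48], "depth": [1, 64, 64]},
    #    "obs_disassemble_mapping": {"depth": "<compression_name>"}}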
def load_traj_data(self, env_idx, traj_file_idx, new_episode= False):
""" If new_episode, set the 0-th frame to done, making sure the agent is reset.
"""
traj_dir = self.current_traj_dirs[env_idx]
try:
with open(osp.join(traj_dir, self.trajectory_files[env_idx][traj_file_idx]), "rb") as f:
traj_data = pickle.load(f)
        except (OSError, IndexError, EOFError, pickle.UnpicklingError):
            try:
                traj_file = osp.join(traj_dir, self.trajectory_files[env_idx][traj_file_idx])
                print("Failed to load", traj_file)
            except IndexError:
                print("Failed to load file")
            # The caller will know that the file is absent, then switch to a new trajectory
            return None
        # reassemble the observation components if they are disassembled in the pickle files
        if "obs_disassemble_mapping" in self.metadata:
            traj_data = self.assemble_obs_components(traj_data)
if self.load_data_to_device:
for data_key, data_val in traj_data.items():
traj_data[data_key] = torch.from_numpy(data_val).to(self.rl_device)
if new_episode:
# add done flag to the 0-th step of newly loaded trajectory
traj_data["dones"][0] = True
return traj_data
def get_transition_batch(self):
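        """ Assemble one transition per env into a batch.
        Returns (Transitions, infos), or (None, None) once the dataset is drained.
        Stacked shapes (inferred from the code below): observation is
        (num_envs, d_obs); reward and done are squeezed to (num_envs,).
        """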
        if not hasattr(self, "dataset_drained"):
            # initialize the dataset if it is not being used as an iterator
self.initialize()
observations = []
privileged_observations = []
actions = []
rewards = []
dones = []
time_outs = []
if self.dataset_drained:
return None, None
for env_idx in range(self.num_envs):
traj_data = self.traj_datas[env_idx]
traj_step_idx = self.traj_step_idxs[env_idx]
observations.append(traj_data["observations"][traj_step_idx])
privileged_observations.append(traj_data["privileged_observations"][traj_step_idx])
actions.append(traj_data["actions"][traj_step_idx])
rewards.append(traj_data["rewards"][traj_step_idx])
dones.append(traj_data["dones"][traj_step_idx])
if "timeouts" in traj_data: time_outs.append(traj_data["timeouts"][traj_step_idx])
self.traj_step_idxs[env_idx] += 1
traj_update_result = self.update_traj_data_if_needed(env_idx)
if traj_update_result == "drained":
self.dataset_drained = True
return None, None
elif traj_update_result == "new_traj":
dones[-1][:] = True
if torch.is_tensor(observations[0]):
observations = torch.stack(observations)
else:
observations = torch.from_numpy(np.stack(observations)).to(self.rl_device)
if torch.is_tensor(privileged_observations[0]):
privileged_observations = torch.stack(privileged_observations)
else:
privileged_observations = torch.from_numpy(np.stack(privileged_observations)).to(self.rl_device)
if torch.is_tensor(actions[0]):
actions = torch.stack(actions)
else:
actions = torch.from_numpy(np.stack(actions)).to(self.rl_device)
if torch.is_tensor(rewards[0]):
            rewards = torch.stack(rewards).squeeze(-1) # squeeze the trailing dimension to match the simulator env's reward shape
else:
rewards = torch.from_numpy(np.stack(rewards)).to(self.rl_device).squeeze(-1)
if torch.is_tensor(dones[0]):
dones = torch.stack(dones).to(bool).squeeze(-1)
else:
dones = torch.from_numpy(np.stack(dones)).to(self.rl_device).to(bool).squeeze(-1)
infos = dict()
if time_outs:
if torch.is_tensor(time_outs[0]):
infos["time_outs"] = torch.stack(time_outs)
else:
infos["time_outs"] = torch.from_numpy(np.stack(time_outs)).to(self.rl_device)
infos["num_looped"] = self.num_looped
return self.Transitions(
observation= observations,
privileged_observation= privileged_observations,
action= actions,
reward= rewards,
done= dones,
), infos
def update_traj_data_if_needed(self, env_idx):
""" Return 'new_file', 'new_traj', 'drained', or None
"""
traj_data = self.traj_datas[env_idx]
if self.traj_step_idxs[env_idx] >= len(traj_data["rewards"]):
# to next file
self.traj_file_idxs[env_idx] += 1
self.traj_step_idxs[env_idx] = 0
traj_data = None
new_episode = False
while traj_data is None:
if self.traj_file_idxs[env_idx] >= len(self.trajectory_files[env_idx]):
# to next trajectory
if len(self.unused_traj_dirs) == 0 or not osp.isdir(self.unused_traj_dirs[0]):
if self.max_loops > 0 and self.num_looped >= self.max_loops:
return 'drained'
else:
self.num_looped += 1
self.initialize()
return 'new_traj'
starting_frame = torch.randint(self.starting_frame_range[0], self.starting_frame_range[1], (1,)).item()
self.update_traj_handle(env_idx, self.unused_traj_dirs.pop(0), starting_frame)
traj_data = self.traj_datas[env_idx]
else:
traj_data = self.load_traj_data(
env_idx,
self.traj_file_idxs[env_idx],
new_episode= new_episode,
)
if traj_data is None:
self.nullify_traj_handles(env_idx)
else:
self.traj_datas[env_idx] = traj_data
return 'new_file'
return None
def set_traj_idx(self, traj_idx, env_idx= 0):
""" Allow users to select a specific trajectory to start from """
self.current_traj_dirs[env_idx] = self.unused_traj_dirs[traj_idx]
self.traj_file_idxs[env_idx] = 0
self.traj_step_idxs[env_idx] = 0
self.trajectory_files[env_idx] = sorted(
os.listdir(self.current_traj_dirs[env_idx]),
key= lambda x: int(x.split("_")[1]),
)
self.traj_datas[env_idx] = self.load_traj_data(env_idx, self.traj_file_idxs[env_idx])
self.dataset_drained = False
##### Interfaces for the IterableDataset #####
def __iter__(self):
self.initialize()
transition_batch, infos = self.get_transition_batch()
while transition_batch is not None:
yield transition_batch, infos
transition_batch, infos = self.get_transition_batch()
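

if __name__ == "__main__":
    # Minimal smoke test; a sketch only, assuming a dataset recorded in the
    # layout described above exists at this (hypothetical) path.
    from torch.utils.data import DataLoader
    dataset = RolloutDataset(
        data_dir= "logs/rollout_dataset",
        num_envs= 2,
        dataset_loops= 1,
        rl_device= "cpu",
    )
    # batch_size=None disables automatic batching, so the pre-batched
    # (Transitions, infos) pairs from __iter__ pass through unchanged
    loader = DataLoader(dataset, batch_size= None, num_workers= 0)
    for i, (transitions, infos) in enumerate(loader):
        print(i, tuple(transitions.observation.shape), transitions.done.tolist())
        if i >= 3:
            break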