add reward
Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
This commit is contained in:
parent
3e4d7beb5d
commit
61e51c9fe4
|
@ -103,6 +103,7 @@ def load_from_raw(
|
|||
|
||||
states = []
|
||||
actions = []
|
||||
rewards = torch.zeros(num_frames, dtype=torch.float32)
|
||||
ep_dict = {}
|
||||
|
||||
image_array_dict = {key: [] for key in image_keys}
|
||||
|
@ -112,9 +113,9 @@ def load_from_raw(
|
|||
for j, step in enumerate(steps):
|
||||
states.append(tf_to_torch(step['observation']['state']))
|
||||
actions.append(tf_to_torch(step['action']))
|
||||
rewards[j] = torch.tensor(step['reward'].numpy(), dtype=torch.float32)
|
||||
|
||||
# if "language_text" in step:
|
||||
# print(" - lang: ", step["language_text"])
|
||||
# TODO: language_text, is_terminal, is_last etc.
|
||||
|
||||
for im_key in image_keys:
|
||||
if im_key not in step['observation']:
|
||||
|
@ -156,6 +157,7 @@ def load_from_raw(
|
|||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["reward"] = rewards
|
||||
ep_dict["next.done"] = done
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
@ -198,6 +200,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
|||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
features["reward"] = Value(dtype="float32", id=None)
|
||||
features["next.done"] = Value(dtype="bool", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
|
@ -229,8 +232,8 @@ def from_raw_to_lerobot_format(
|
|||
|
||||
if __name__ == "__main__":
|
||||
# TODO (YL) remove this
|
||||
raw_dir = Path("/hdd/serl/serl_task1_combine_13jun/")
|
||||
videos_dir = Path("/hdd/serl/tmp/")
|
||||
raw_dir = Path("/hdd/tensorflow_datasets/austin_buds_dataset_converted_externally_to_rlds/0.1.0/")
|
||||
videos_dir = Path("/hdd/tmp/")
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
raw_dir, videos_dir, fps=5, video=True, episodes=None,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue