Implement 2nd round of review changes

This commit is contained in:
Radek Osmulski 2024-05-28 23:14:08 +10:00
parent 566a8aa98e
commit 621f69d98f
3 changed files with 6 additions and 5 deletions

View File

@ -25,7 +25,7 @@ class EpisodeAwareSampler:
episode_indices_to_use: Union[list, None] = None, episode_indices_to_use: Union[list, None] = None,
drop_n_first_frames: int = 0, drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0, drop_n_last_frames: int = 0,
shuffle: bool = True, shuffle: bool = False,
): ):
"""Sampler that optionally incorporates episode boundary information. """Sampler that optionally incorporates episode boundary information.
@ -35,7 +35,7 @@ class EpisodeAwareSampler:
Assumes that episodes are indexed from 0 to N-1. Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0. drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0.
drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0. drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0.
shuffle (bool, optional): Whether to shuffle the indices. Defaults to True. shuffle (bool, optional): Whether to shuffle the indices. Defaults to False.
""" """
indices = [] indices = []
for episode_idx, (start_index, end_index) in enumerate( for episode_idx, (start_index, end_index) in enumerate(

View File

@ -355,6 +355,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
sampler=EpisodeAwareSampler( sampler=EpisodeAwareSampler(
offline_dataset.episode_data_index, offline_dataset.episode_data_index,
drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0, drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0,
shuffle=True,
), ),
pin_memory=device.type != "cpu", pin_memory=device.type != "cpu",
drop_last=False, drop_last=False,

View File

@ -35,7 +35,7 @@ def test_drop_n_first_frames():
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1) sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5] assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3 assert len(sampler) == 3
assert set(sampler) == {1, 4, 5} assert list(sampler) == [1, 4, 5]
def test_drop_n_last_frames(): def test_drop_n_last_frames():
@ -51,7 +51,7 @@ def test_drop_n_last_frames():
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1) sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4] assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3 assert len(sampler) == 3
assert set(sampler) == {0, 3, 4} assert list(sampler) == [0, 3, 4]
def test_episode_indices_to_use(): def test_episode_indices_to_use():
@ -67,7 +67,7 @@ def test_episode_indices_to_use():
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2]) sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
assert sampler.indices == [0, 1, 3, 4, 5] assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5 assert len(sampler) == 5
assert set(sampler) == {0, 1, 3, 4, 5} assert list(sampler) == [0, 1, 3, 4, 5]
def test_shuffle(): def test_shuffle():