Implement 2nd round of review changes

This commit is contained in:
Radek Osmulski 2024-05-28 23:14:08 +10:00
parent 566a8aa98e
commit 621f69d98f
3 changed files with 6 additions and 5 deletions

View File

@ -25,7 +25,7 @@ class EpisodeAwareSampler:
episode_indices_to_use: Union[list, None] = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = True,
shuffle: bool = False,
):
"""Sampler that optionally incorporates episode boundary information.
@ -35,7 +35,7 @@ class EpisodeAwareSampler:
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0.
drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0.
shuffle (bool, optional): Whether to shuffle the indices. Defaults to True.
shuffle (bool, optional): Whether to shuffle the indices. Defaults to False.
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(

View File

@ -355,6 +355,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
sampler=EpisodeAwareSampler(
offline_dataset.episode_data_index,
drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0,
shuffle=True,
),
pin_memory=device.type != "cpu",
drop_last=False,

View File

@ -35,7 +35,7 @@ def test_drop_n_first_frames():
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert set(sampler) == {1, 4, 5}
assert list(sampler) == [1, 4, 5]
def test_drop_n_last_frames():
@ -51,7 +51,7 @@ def test_drop_n_last_frames():
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert set(sampler) == {0, 3, 4}
assert list(sampler) == [0, 3, 4]
def test_episode_indices_to_use():
@ -67,7 +67,7 @@ def test_episode_indices_to_use():
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert set(sampler) == {0, 1, 3, 4, 5}
assert list(sampler) == [0, 1, 3, 4, 5]
def test_shuffle():