Implement 2nd round of review changes
This commit is contained in:
parent
566a8aa98e
commit
621f69d98f
|
@ -25,7 +25,7 @@ class EpisodeAwareSampler:
|
|||
episode_indices_to_use: Union[list, None] = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = True,
|
||||
shuffle: bool = False,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
|
@ -35,7 +35,7 @@ class EpisodeAwareSampler:
|
|||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0.
|
||||
drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0.
|
||||
shuffle (bool, optional): Whether to shuffle the indices. Defaults to True.
|
||||
shuffle (bool, optional): Whether to shuffle the indices. Defaults to False.
|
||||
"""
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
|
|
|
@ -355,6 +355,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
sampler=EpisodeAwareSampler(
|
||||
offline_dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.drop_n_last_frames if hasattr(cfg, "drop_n_last_frames") else 0,
|
||||
shuffle=True,
|
||||
),
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
|
|
|
@ -35,7 +35,7 @@ def test_drop_n_first_frames():
|
|||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
|
||||
assert sampler.indices == [1, 4, 5]
|
||||
assert len(sampler) == 3
|
||||
assert set(sampler) == {1, 4, 5}
|
||||
assert list(sampler) == [1, 4, 5]
|
||||
|
||||
|
||||
def test_drop_n_last_frames():
|
||||
|
@ -51,7 +51,7 @@ def test_drop_n_last_frames():
|
|||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
|
||||
assert sampler.indices == [0, 3, 4]
|
||||
assert len(sampler) == 3
|
||||
assert set(sampler) == {0, 3, 4}
|
||||
assert list(sampler) == [0, 3, 4]
|
||||
|
||||
|
||||
def test_episode_indices_to_use():
|
||||
|
@ -67,7 +67,7 @@ def test_episode_indices_to_use():
|
|||
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
|
||||
assert sampler.indices == [0, 1, 3, 4, 5]
|
||||
assert len(sampler) == 5
|
||||
assert set(sampler) == {0, 1, 3, 4, 5}
|
||||
assert list(sampler) == [0, 1, 3, 4, 5]
|
||||
|
||||
|
||||
def test_shuffle():
|
||||
|
|
Loading…
Reference in New Issue