Implement review feedback

This commit is contained in:
Radek Osmulski 2024-05-28 22:10:15 +10:00
parent e7abcc2ffd
commit 566a8aa98e
2 changed files with 44 additions and 5 deletions

View File

@ -13,18 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Iterator, Union
from torch.utils.data import SubsetRandomSampler
import torch
class EpisodeAwareSampler(SubsetRandomSampler):
class EpisodeAwareSampler:
def __init__(
self,
episode_data_index: dict,
episode_indices_to_use: Union[list, None] = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = True,
):
"""Sampler that optionally incorporates episode boundary information.
@ -34,14 +35,27 @@ class EpisodeAwareSampler(SubsetRandomSampler):
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames (int, optional): Number of frames to drop from the start of each episode. Defaults to 0.
drop_n_last_frames (int, optional): Number of frames to drop from the end of each episode. Defaults to 0.
shuffle (bool, optional): Whether to shuffle the indices. Defaults to True.
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(episode_data_index["from"], episode_data_index["to"], strict=False)
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
)
super().__init__(indices)
self.indices = indices
self.shuffle = shuffle
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
yield i
def __len__(self) -> int:
return len(self.indices)

View File

@ -34,6 +34,8 @@ def test_drop_n_first_frames():
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert set(sampler) == {1, 4, 5}
def test_drop_n_last_frames():
@ -48,6 +50,8 @@ def test_drop_n_last_frames():
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert set(sampler) == {0, 3, 4}
def test_episode_indices_to_use():
@ -62,3 +66,24 @@ def test_episode_indices_to_use():
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert set(sampler) == {0, 1, 3, 4, 5}
def test_shuffle():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert list(sampler) == [0, 1, 2, 3, 4, 5]
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}