diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index ae82c361..34d92daa 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -145,6 +145,9 @@ class PushtDataset(torch.utils.data.Dataset): assert (episode_ids[idx0:idx1] == episode_id).all() image = imgs[idx0:idx1] + assert image.min() >= 0.0 + assert image.max() <= 255.0 + image = image.type(torch.uint8) state = states[idx0:idx1] agent_pos = state[:, :2] diff --git a/poetry.lock b/poetry.lock index 3c02e2f6..0133b3ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -921,7 +921,7 @@ shapely = "^2.0.3" type = "git" url = "git@github.com:huggingface/gym-pusht.git" reference = "HEAD" -resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1" +resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06" [[package]] name = "gym-xarm" diff --git a/tests/data/pusht/data_dict.pth b/tests/data/pusht/data_dict.pth index 40d96a51..a083c86c 100644 Binary files a/tests/data/pusht/data_dict.pth and b/tests/data/pusht/data_dict.pth differ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 37e320a2..71eefa9c 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -51,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name): keys_ndim_required.append( (key, 3, True), ) + assert dataset.data_dict[key].dtype == torch.uint8, f"{key}" # test number of dimensions for key, ndim, required in keys_ndim_required: