fix pusht images type from float32 to uint8, update gym-pusht dependencies

This commit is contained in:
Cadene 2024-04-11 14:29:16 +00:00
parent 4216636084
commit c1a618e567
4 changed files with 6 additions and 2 deletions

View File

@ -145,6 +145,9 @@ class PushtDataset(torch.utils.data.Dataset):
assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[idx0:idx1]
agent_pos = state[:, :2]

4
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]]
name = "absl-py"
@ -921,7 +921,7 @@ shapely = "^2.0.3"
type = "git"
url = "git@github.com:huggingface/gym-pusht.git"
reference = "HEAD"
resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1"
resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06"
[[package]]
name = "gym-xarm"

Binary file not shown.

View File

@ -51,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required.append(
(key, 3, True),
)
assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
# test number of dimensions
for key, ndim, required in keys_ndim_required: