fix pusht images type from float32 to uint8, update gym-pusht dependencies
This commit is contained in:
parent
4216636084
commit
c1a618e567
|
@ -145,6 +145,9 @@ class PushtDataset(torch.utils.data.Dataset):
|
|||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
|
@ -921,7 +921,7 @@ shapely = "^2.0.3"
|
|||
type = "git"
|
||||
url = "git@github.com:huggingface/gym-pusht.git"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1"
|
||||
resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06"
|
||||
|
||||
[[package]]
|
||||
name = "gym-xarm"
|
||||
|
|
Binary file not shown.
|
@ -51,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
|
|||
keys_ndim_required.append(
|
||||
(key, 3, True),
|
||||
)
|
||||
assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
|
||||
|
||||
# test number of dimensions
|
||||
for key, ndim, required in keys_ndim_required:
|
||||
|
|
Loading…
Reference in New Issue