fix pusht images type from float32 to uint8, update gym-pusht dependencies
This commit is contained in:
parent
4216636084
commit
c1a618e567
|
@ -145,6 +145,9 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||||
|
|
||||||
image = imgs[idx0:idx1]
|
image = imgs[idx0:idx1]
|
||||||
|
assert image.min() >= 0.0
|
||||||
|
assert image.max() <= 255.0
|
||||||
|
image = image.type(torch.uint8)
|
||||||
|
|
||||||
state = states[idx0:idx1]
|
state = states[idx0:idx1]
|
||||||
agent_pos = state[:, :2]
|
agent_pos = state[:, :2]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "absl-py"
|
name = "absl-py"
|
||||||
|
@ -921,7 +921,7 @@ shapely = "^2.0.3"
|
||||||
type = "git"
|
type = "git"
|
||||||
url = "git@github.com:huggingface/gym-pusht.git"
|
url = "git@github.com:huggingface/gym-pusht.git"
|
||||||
reference = "HEAD"
|
reference = "HEAD"
|
||||||
resolved_reference = "824b22832cc8d71a4b4e96a57563510cf47e30c1"
|
resolved_reference = "080d4ce4d8d3140b2fd204ed628bda14dc58ff06"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gym-xarm"
|
name = "gym-xarm"
|
||||||
|
|
Binary file not shown.
|
@ -51,6 +51,7 @@ def test_factory(env_name, dataset_id, policy_name):
|
||||||
keys_ndim_required.append(
|
keys_ndim_required.append(
|
||||||
(key, 3, True),
|
(key, 3, True),
|
||||||
)
|
)
|
||||||
|
assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
|
||||||
|
|
||||||
# test number of dimensions
|
# test number of dimensions
|
||||||
for key, ndim, required in keys_ndim_required:
|
for key, ndim, required in keys_ndim_required:
|
||||||
|
|
Loading…
Reference in New Issue