Remove batch length restrictions in select_action (#123)
This commit is contained in:
parent
bccee745c3
commit
c015252e20
|
@ -71,6 +71,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||||
queue is empty.
|
queue is empty.
|
||||||
"""
|
"""
|
||||||
|
assert "observation.images.top" in batch
|
||||||
|
assert "observation.state" in batch
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
|
@ -115,7 +115,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
"""
|
"""
|
||||||
assert "observation.image" in batch
|
assert "observation.image" in batch
|
||||||
assert "observation.state" in batch
|
assert "observation.state" in batch
|
||||||
assert len(batch) == 2
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
|
|
|
@ -122,7 +122,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
"""Select a single action given environment observations."""
|
"""Select a single action given environment observations."""
|
||||||
assert "observation.image" in batch
|
assert "observation.image" in batch
|
||||||
assert "observation.state" in batch
|
assert "observation.state" in batch
|
||||||
assert len(batch) == 2
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue