Merge branch 'main' into user/rcadene/2024_05_02_visu_rerun

This commit is contained in:
Remi 2024-05-04 16:04:06 +02:00 committed by GitHub
commit ae79e7555e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 3 additions and 2 deletions

View File

@ -71,6 +71,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
environment. It works by managing the actions in a queue and only calling `select_actions` when the environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty. queue is empty.
""" """
assert "observation.images.top" in batch
assert "observation.state" in batch
self.eval() self.eval()
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)

View File

@ -115,7 +115,6 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
""" """
assert "observation.image" in batch assert "observation.image" in batch
assert "observation.state" in batch assert "observation.state" in batch
assert len(batch) == 2
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)

View File

@ -122,7 +122,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
"""Select a single action given environment observations.""" """Select a single action given environment observations."""
assert "observation.image" in batch assert "observation.image" in batch
assert "observation.state" in batch assert "observation.state" in batch
assert len(batch) == 2
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)