backup wip

Alexander Soare 2024-05-10 07:07:55 +01:00
parent 77aa80e198
commit fb202b5040
2 changed files with 3 additions and 5 deletions

View File

@@ -9,7 +9,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 def make_dataset(
     cfg,
     split="train",
-):
+) -> LeRobotDataset:
     if cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
             f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "

View File

@@ -301,7 +301,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.
-        Returns a dictionary with loss as a tensor, and scalar valued
+        Returns a dictionary with loss as a tensor, and other information as native floats.
         """
         device = get_device_from_parameters(self)
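To make the revised docstring concrete, here is a hedged sketch of the return contract it describes: a "loss" entry kept as a Tensor so it can be backpropagated, and the remaining entries as native Python floats. Everything except the "loss" key is invented for illustration and is not the real TDMPC output.

```python
# Toy illustration of the documented contract: "loss" stays a Tensor,
# other entries are native floats. Keys other than "loss" are made up.
import torch
from torch import Tensor


def toy_forward(batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
    loss = batch["observation"].pow(2).mean()        # Tensor, keeps the autograd graph
    return {"loss": loss, "some_metric": float(loss.detach())}


out = toy_forward({"observation": torch.randn(4, 3, requires_grad=True)})
out["loss"].backward()                               # only the Tensor entry supports backprop
scalars = {k: v for k, v in out.items() if isinstance(v, float)}  # e.g. for logging
```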
@@ -310,9 +310,6 @@
         info = {}
         # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
-        batch_size = batch["index"].shape[0]
         # (b, t) -> (t, b)
         for key in batch:
             if batch[key].ndim > 1:
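The truncated loop above reorders each multi-dimensional batch entry from batch-first to time-first, per the `(b, t) -> (t, b)` comment. Below is a minimal sketch of that behaviour, assuming the loop body is a simple transpose of the first two dimensions (the body itself is cut off in this hunk).

```python
# Minimal sketch of the (b, t) -> (t, b) reordering; the transpose is assumed,
# since the loop body is truncated in the hunk above.
import torch

batch = {
    "action": torch.zeros(8, 5, 2),  # (b=8, t=5, action_dim=2)
    "index": torch.arange(8),        # 1D, skipped by the ndim > 1 check
}
for key in batch:
    if batch[key].ndim > 1:
        batch[key] = batch[key].transpose(0, 1)  # (b, t, ...) -> (t, b, ...)

print(batch["action"].shape)  # torch.Size([5, 8, 2])
```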
@@ -340,6 +337,7 @@
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
+        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
         z_preds[0] = self.model.encode(current_observation)
         reward_preds = torch.empty_like(reward, device=device)
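For context on why `z_preds` holds `horizon + 1` latents, here is a rough, self-contained sketch of the rollout these buffers are sized for. The encoder, dynamics, and reward functions are stand-ins with made-up shapes, not the real TDMPC modules.

```python
# Stand-in rollout: `horizon` actions each produce the next latent, so there
# are horizon + 1 latents in total. All modules and shapes here are dummies.
import torch

horizon, batch_size, latent_dim, action_dim, obs_dim = 5, 8, 50, 2, 10
device = "cpu"

def encode(obs):                 # stand-in for the observation encoder
    return torch.zeros(obs.shape[0], latent_dim, device=device)

def latent_dynamics(z, a):       # stand-in for the latent dynamics model
    return z + 0.0 * a.sum(-1, keepdim=True)

def predict_reward(z, a):        # stand-in for the reward head
    return torch.zeros(z.shape[0], device=device)

current_observation = torch.zeros(batch_size, obs_dim, device=device)
action = torch.zeros(horizon, batch_size, action_dim, device=device)

z_preds = torch.empty(horizon + 1, batch_size, latent_dim, device=device)
z_preds[0] = encode(current_observation)
reward_preds = torch.empty(horizon, batch_size, device=device)
for t in range(horizon):
    z_preds[t + 1] = latent_dynamics(z_preds[t], action[t])
    reward_preds[t] = predict_reward(z_preds[t], action[t])
```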