backup wip
This commit is contained in:
parent
77aa80e198
commit
fb202b5040
|
@ -9,7 +9,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|||
def make_dataset(
|
||||
cfg,
|
||||
split="train",
|
||||
):
|
||||
) -> LeRobotDataset:
|
||||
if cfg.env.name not in cfg.dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
|
||||
|
|
|
@ -301,7 +301,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and scalar valued
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
|
@ -310,9 +310,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
info = {}
|
||||
|
||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
||||
batch_size = batch["index"].shape[0]
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
|
@ -340,6 +337,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
|
|
Loading…
Reference in New Issue