From 918868162eaae31c947812877c8656371facdad0 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Thu, 16 May 2024 09:51:01 +0100
Subject: [PATCH] backup

---
 .../common/policies/tdmpc/modeling_tdmpc.py | 75 ++++++++++++-------
 1 file changed, 47 insertions(+), 28 deletions(-)

diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index c8a4d8c9..ed245618 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -13,7 +13,6 @@ TODO(alexander-soare): Use batch-first throughout.
 import logging
 from collections import deque
 from copy import deepcopy
-from functools import partial
 from typing import Callable
 
 import einops
@@ -308,6 +307,17 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
 
+        import os
+        from time import sleep
+
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            batch.update(torch.load("/tmp/batch.pth"))
+            os.remove("/tmp/mutex.txt")
+            break
+
         info = {}
 
         # (b, t) -> (t, b)
@@ -315,16 +325,16 @@
             if batch[key].ndim > 1:
                 batch[key] = batch[key].transpose(1, 0)
 
-        action = batch["action"]  # (t, b)
-        reward = batch["next.reward"]  # (t,)
+        action = batch["action"]  # (t, b, action_dim)
+        reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
-        # Apply random image augmentations.
-        if self.config.max_random_shift_ratio > 0:
-            observations["observation.image"] = flatten_forward_unflatten(
-                partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
-                observations["observation.image"],
-            )
+        # # Apply random image augmentations.
+        # if self.config.max_random_shift_ratio > 0:
+        #     observations["observation.image"] = flatten_forward_unflatten(
+        #         partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
+        #         observations["observation.image"],
+        #     )
 
         # Get the current observation for predicting trajectories, and all future observations for use in
         # the latent consistency loss and TD loss.
@@ -332,37 +342,44 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         for k in observations:
             current_observation[k] = observations[k][0]
             next_observations[k] = observations[k][1:]
-        horizon = next_observations["observation.image"].shape[0]
+        horizon, batch_size = next_observations["observation.image"].shape[:2]
 
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
-        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
-        z_preds[0] = self.model.encode(current_observation)
+        z_preds[0] = self.model.encode(current_observation)  # TODO(now): Same
         reward_preds = torch.empty_like(reward, device=device)
         for t in range(horizon):
-            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
+            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
+                z_preds[t], action[t]
+            )  # TODO(now): same
 
         # Compute Q and V value predictions based on the latent rollout.
-        q_preds_ensemble = self.model.Qs(z_preds[:-1], action)  # (ensemble, horizon, batch)
-        v_preds = self.model.V(z_preds[:-1])
+        q_preds_ensemble = self.model.Qs(
+            z_preds[:-1], action
+        )  # (ensemble, horizon, batch)  # TODO(now): all zeros
+        v_preds = self.model.V(z_preds[:-1])  # TODO(now): same
         info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
 
         # Compute various targets with stopgrad.
         with torch.no_grad():
             # Latent state consistency targets.
-            z_targets = self.model_target.encode(next_observations)
+            z_targets = self.model_target.encode(next_observations)  # TODO(now): same
             # State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
             # learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
             # uses a learned state value function: V(z). This means the TD targets only depend on in-sample
             # actions (not actions estimated by π).
             # Note: Here we do not use self.model_target, but self.model. This is to follow the original code
             # and the FOWM paper.
-            q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
+            q_targets = reward + self.config.discount * self.model.V(
+                self.model.encode(next_observations)
+            )  # TODO(now): same
             # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
             # are using them to compute loss for V.
-            v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
+            v_targets = self.model_target.Qs(
+                z_preds[:-1].detach(), action, return_min=True
+            )  # TODO(now): zeros
 
         # Compute losses.
         # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@@ -383,7 +400,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
         # rewards.
@@ -397,12 +414,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["action_is_pad"]
            )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         q_value_loss = (
             (
-                F.mse_loss(
+                temporal_loss_coeffs
+                * F.mse_loss(
                     q_preds_ensemble,
                     einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
                     reduction="none",
@@ -415,7 +433,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute state value loss as in eqn 3 of FOWM.
         diff = v_targets - v_preds
@@ -434,7 +452,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["action_is_pad"]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
 
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
@@ -451,10 +469,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         action_preds = self.model.pi(z_preds[:-1])  # (t, b, a)
         # Calculate the MSE between the actions and the action predictions.
         # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
-        # gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
-        # the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
-        # as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
-        # parameter for it (see below where we compute the total loss).
+        # gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
+        # multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
+        # dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
+        # the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
+        # loss).
         mse = F.mse_loss(action_preds, action, reduction="none").sum(-1)  # (t, b)
         # NOTE: The original implementation does not take the sum over the temporal dimension like with the
         # other losses.
@@ -467,7 +486,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             # `action_preds` depends on the first observation and the actions.
             * ~batch["observation.state_is_pad"][0]
             * ~batch["action_is_pad"]
-        ).mean()
+        ).mean()  # TODO(now): same
 
         loss = (
             self.config.consistency_coeff * consistency_loss
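A minimal standalone sketch, not part of the patch above, to illustrate the relationship stated in the
rewritten comment in the last hunk: the negative log probability of the actions under a unit-standard-deviation
Gaussian equals 0.5 times the MSE summed over the action dimension, plus a constant offset of log(2*pi)/2 per
action dimension. The tensor shapes and variable names below are hypothetical stand-ins, not taken from
modeling_tdmpc.py.

    import math

    import torch
    import torch.nn.functional as F

    t, b, action_dim = 5, 2, 4  # hypothetical horizon, batch size, and action dimension
    action_preds = torch.randn(t, b, action_dim)  # stand-in for self.model.pi(z_preds[:-1])
    action = torch.randn(t, b, action_dim)  # stand-in for the dataset actions

    # MSE summed over the action dimension, as in the patch: shape (t, b).
    mse = F.mse_loss(action_preds, action, reduction="none").sum(-1)
    # Negative log probability under a unit-standard-deviation Gaussian, also summed over the action dimension.
    nll = -torch.distributions.Normal(action_preds, 1.0).log_prob(action).sum(-1)
    # The two agree up to a factor of 0.5 and a constant offset of action_dim * log(2 * pi) / 2.
    assert torch.allclose(nll, 0.5 * mse + 0.5 * action_dim * math.log(2 * math.pi), atol=1e-5)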