Alexander Soare 2024-05-16 09:51:01 +01:00
parent fb202b5040
commit 918868162e
1 changed file with 47 additions and 28 deletions


@@ -13,7 +13,6 @@ TODO(alexander-soare): Use batch-first throughout.
 import logging
 from collections import deque
 from copy import deepcopy
-from functools import partial
 from typing import Callable
 
 import einops
@@ -308,6 +307,17 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
 
+        import os
+        from time import sleep
+
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            batch.update(torch.load("/tmp/batch.pth"))
+            os.remove("/tmp/mutex.txt")
+            break
+
         info = {}
 
         # (b, t) -> (t, b)
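Note on the block added above: it busy-waits for a sentinel file (`/tmp/mutex.txt`), merges a batch saved at `/tmp/batch.pth` into the current batch, then deletes the sentinel. This reads as a temporary debugging hook for feeding the policy a batch produced by another process (for example, a reference implementation). A minimal sketch of the producer side this loop expects; the function name is hypothetical and only the file paths are taken from the diff:

import torch

def publish_batch(batch: dict) -> None:
    # Write the batch first, then create the sentinel file, so the consumer's
    # existence check on /tmp/mutex.txt only succeeds once the batch is fully written.
    torch.save(batch, "/tmp/batch.pth")
    with open("/tmp/mutex.txt", "w") as f:
        f.write("ready")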
@@ -315,16 +325,16 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             if batch[key].ndim > 1:
                 batch[key] = batch[key].transpose(1, 0)
 
-        action = batch["action"]  # (t, b)
-        reward = batch["next.reward"]  # (t,)
+        action = batch["action"]  # (t, b, action_dim)
+        reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
-        # Apply random image augmentations.
-        if self.config.max_random_shift_ratio > 0:
-            observations["observation.image"] = flatten_forward_unflatten(
-                partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
-                observations["observation.image"],
-            )
+        # # Apply random image augmentations.
+        # if self.config.max_random_shift_ratio > 0:
+        #     observations["observation.image"] = flatten_forward_unflatten(
+        #         partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
+        #         observations["observation.image"],
+        #     )
 
         # Get the current observation for predicting trajectories, and all future observations for use in
         # the latent consistency loss and TD loss.
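For context on what the hunk above disables: `random_shifts_aug` is a DrQ-style random shift augmentation (pad the image, then take a random crop of the original size). A rough sketch of the idea, not the repository's implementation; the function below is illustrative only:

import torch
import torch.nn.functional as F

def random_shift(images: torch.Tensor, max_random_shift_ratio: float) -> torch.Tensor:
    """Randomly translate a (b, c, h, w) batch by up to max_random_shift_ratio * h pixels."""
    b, _, h, w = images.shape
    pad = int(round(max_random_shift_ratio * h))
    padded = F.pad(images, [pad] * 4, mode="replicate")
    out = torch.empty_like(images)
    for i in range(b):
        # Pick an independent crop offset per sample.
        dx = torch.randint(0, 2 * pad + 1, (1,)).item()
        dy = torch.randint(0, 2 * pad + 1, (1,)).item()
        out[i] = padded[i, :, dy : dy + h, dx : dx + w]
    return out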
@@ -332,37 +342,44 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         for k in observations:
             current_observation[k] = observations[k][0]
             next_observations[k] = observations[k][1:]
-        horizon = next_observations["observation.image"].shape[0]
+        horizon, batch_size = next_observations["observation.image"].shape[:2]
 
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
-        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
-        z_preds[0] = self.model.encode(current_observation)
+        z_preds[0] = self.model.encode(current_observation)  # TODO(now): Same
         reward_preds = torch.empty_like(reward, device=device)
         for t in range(horizon):
-            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
+            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
+                z_preds[t], action[t]
+            )  # TODO(now): same
 
         # Compute Q and V value predictions based on the latent rollout.
-        q_preds_ensemble = self.model.Qs(z_preds[:-1], action)  # (ensemble, horizon, batch)
-        v_preds = self.model.V(z_preds[:-1])
+        q_preds_ensemble = self.model.Qs(
+            z_preds[:-1], action
+        )  # (ensemble, horizon, batch) # TODO(now): all zeros
+        v_preds = self.model.V(z_preds[:-1])  # TODO(now): same
         info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
 
         # Compute various targets with stopgrad.
         with torch.no_grad():
             # Latent state consistency targets.
-            z_targets = self.model_target.encode(next_observations)
+            z_targets = self.model_target.encode(next_observations)  # TODO(now): same
             # State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
             # learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
             # uses a learned state value function: V(z). This means the TD targets only depend on in-sample
             # actions (not actions estimated by π).
             # Note: Here we do not use self.model_target, but self.model. This is to follow the original code
             # and the FOWM paper.
-            q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
+            q_targets = reward + self.config.discount * self.model.V(
+                self.model.encode(next_observations)
+            )  # TODO(now): same
             # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
             # are using them to compute loss for V.
-            v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
+            v_targets = self.model_target.Qs(
+                z_preds[:-1].detach(), action, return_min=True
+            )  # TODO(now): zeros
 
         # Compute losses.
         # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
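To make the shapes in the hunk above concrete: the TD target from eqn 3 of FOWM is q_target_t = r_t + discount * V(z_{t+1}), computed per timestep and per batch element. A toy sketch with dummy tensors and a stand-in value function, purely to pin down the (t, b) shapes (none of the names below come from the commit):

import torch

horizon, batch_size, latent_dim = 5, 2, 50
discount = 0.9

reward = torch.rand(horizon, batch_size)               # r_t, shape (t, b)
next_z = torch.rand(horizon, batch_size, latent_dim)   # encoded next observations

# Stand-in for the learned state value function V(z): maps (..., latent_dim) -> (...).
def V(z: torch.Tensor) -> torch.Tensor:
    return z.mean(-1)

# TD targets as in eqn 3 of FOWM: q_target_t = r_t + discount * V(z_{t+1}), shape (t, b).
q_targets = reward + discount * V(next_z)
assert q_targets.shape == (horizon, batch_size)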
@@ -383,7 +400,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
         # rewards.
@@ -397,12 +414,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["action_is_pad"]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         q_value_loss = (
             (
-                F.mse_loss(
+                temporal_loss_coeffs
+                * F.mse_loss(
                     q_preds_ensemble,
                     einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
                     reduction="none",
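`temporal_loss_coeffs`, which the hunk above now also applies to the Q loss, is the exponentially decaying per-timestep weight mentioned in the "Exponentially decay the loss weight" comment. Its definition sits outside this diff; a plausible construction (the `temporal_decay_coeff` name and value are assumptions) looks like:

import torch

temporal_decay_coeff = 0.5  # assumed config value
horizon = 5

# Weight timestep t by temporal_decay_coeff ** t: later, more uncertain rollout
# steps contribute less to each loss term.
temporal_loss_coeffs = torch.pow(
    temporal_decay_coeff, torch.arange(horizon, dtype=torch.float32)
).unsqueeze(-1)  # (t, 1), broadcasts against (t, b) per-step losses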
@@ -415,7 +433,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute state value loss as in eqn 3 of FOWM.
         diff = v_targets - v_preds
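The hunk above stops right at `diff = v_targets - v_preds`; the value loss that follows (not shown in this diff) is, in FOWM-style implementations, an expectile (asymmetric squared) loss on this difference, so that V tracks an upper expectile of the Q targets. A generic sketch of such a loss, with the expectile value chosen arbitrarily:

import torch

def expectile_loss(diff: torch.Tensor, expectile: float = 0.7) -> torch.Tensor:
    # Weight positive errors (v_target > v_pred) by `expectile` and negative
    # errors by `1 - expectile`, i.e. |expectile - 1(diff < 0)| * diff**2.
    weight = torch.abs(expectile - (diff < 0).float())
    return weight * diff**2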
@@ -434,7 +452,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["action_is_pad"]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
 
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
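On the advantage weighted regression named in the comment above: FOWM 3.1 weights the behavior-cloning term for π by an exponentiated advantage, which in this batch-of-latents setting is roughly A(z, a) = min_k Q_k(z, a) - V(z) and weight = clamp(exp(scale * A), max=...). A sketch of that weighting; the scale and clamp values below are placeholders, not values from the repository:

import torch

def awr_weights(
    q_min: torch.Tensor, v: torch.Tensor, scale: float = 3.0, max_weight: float = 100.0
) -> torch.Tensor:
    # Advantage of the dataset action over the policy's value estimate, shape (t, b).
    advantage = q_min - v
    # Exponentiate and clamp so large advantages don't blow up the π loss.
    return torch.clamp(torch.exp(scale * advantage), max=max_weight)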
@@ -451,10 +469,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         action_preds = self.model.pi(z_preds[:-1])  # (t, b, a)
         # Calculate the MSE between the actions and the action predictions.
         # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
-        # gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
-        # the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
-        # as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
-        # parameter for it (see below where we compute the total loss).
+        # gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
+        # multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
+        # dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
+        # the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
+        # loss).
         mse = F.mse_loss(action_preds, action, reduction="none").sum(-1)  # (t, b)
         # NOTE: The original implementation does not take the sum over the temporal dimension like with the
         # other losses.
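A quick check of the rewritten comment above, i.e. that the negative log probability of the actions under a unit-variance Gaussian centred on the predictions equals 0.5 * MSE + action_dim * log(2*pi) / 2 once summed over the action dimension (illustrative only, not part of the commit):

import math
import torch
from torch.distributions import Normal

action_dim = 4
pred = torch.randn(3, action_dim)
target = torch.randn(3, action_dim)

mse = ((pred - target) ** 2).sum(-1)  # squared error summed over the action dimension
neg_log_prob = -Normal(pred, 1.0).log_prob(target).sum(-1)

# -log N(target; pred, 1), summed over D dims, equals 0.5 * MSE + D * log(2*pi) / 2.
expected = 0.5 * mse + action_dim * math.log(2 * math.pi) / 2
assert torch.allclose(neg_log_prob, expected, atol=1e-5)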
@@ -467,7 +486,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             # `action_preds` depends on the first observation and the actions.
             * ~batch["observation.state_is_pad"][0]
             * ~batch["action_is_pad"]
-        ).mean()
+        ).mean()  # TODO(now): same
 
         loss = (
             self.config.consistency_coeff * consistency_loss