Alexander Soare 2024-05-16 09:51:01 +01:00
parent fb202b5040
commit 918868162e
1 changed file with 47 additions and 28 deletions


@@ -13,7 +13,6 @@ TODO(alexander-soare): Use batch-first throughout.
import logging
from collections import deque
from copy import deepcopy
-from functools import partial
from typing import Callable
import einops
@@ -308,6 +307,17 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
+import os
+from time import sleep
+while True:
+    if not os.path.exists("/tmp/mutex.txt"):
+        sleep(0.01)
+        continue
+    batch.update(torch.load("/tmp/batch.pth"))
+    os.remove("/tmp/mutex.txt")
+    break
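For context on the handshake above: the training step now blocks until a sentinel file appears, swaps in a batch serialized by another process, and deletes the sentinel to acknowledge receipt. The producer side is not shown anywhere in this commit; a minimal sketch of what it could look like, reusing the hard-coded /tmp paths from the diff:

import os
import time
import torch

def publish_batch(batch: dict) -> None:
    # Write the payload first and create the sentinel second, so the
    # consumer can never observe a half-written batch file.
    torch.save(batch, "/tmp/batch.pth")
    with open("/tmp/mutex.txt", "w"):
        pass
    # The consumer deletes the sentinel once it has loaded the batch;
    # wait for that before publishing the next one.
    while os.path.exists("/tmp/mutex.txt"):
        time.sleep(0.01)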
info = {}
# (b, t) -> (t, b)
@@ -315,16 +325,16 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
if batch[key].ndim > 1:
    batch[key] = batch[key].transpose(1, 0)
-action = batch["action"] # (t, b)
-reward = batch["next.reward"] # (t,)
+action = batch["action"] # (t, b, action_dim)
+reward = batch["next.reward"] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
-# Apply random image augmentations.
-if self.config.max_random_shift_ratio > 0:
-    observations["observation.image"] = flatten_forward_unflatten(
-        partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
-        observations["observation.image"],
-    )
+# # Apply random image augmentations.
+# if self.config.max_random_shift_ratio > 0:
+#     observations["observation.image"] = flatten_forward_unflatten(
+#         partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
+#         observations["observation.image"],
+#     )
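The augmentation disabled above pads each image and re-crops it at a random offset (in the style of DrQ), so the encoder cannot key on absolute pixel positions. random_shifts_aug itself is outside this diff; a self-contained sketch of the general technique, not necessarily the repo's exact implementation:

import torch
import torch.nn.functional as F

def random_shift(images: torch.Tensor, max_shift_ratio: float) -> torch.Tensor:
    """Randomly translate a (b, c, h, w) batch by up to max_shift_ratio * h pixels."""
    b, _, h, w = images.shape
    pad = int(round(max_shift_ratio * h))
    padded = F.pad(images, [pad, pad, pad, pad], mode="replicate")
    # Independent (dy, dx) crop offsets per batch element.
    offsets = torch.randint(0, 2 * pad + 1, (b, 2))
    return torch.stack(
        [padded[i, :, dy : dy + h, dx : dx + w] for i, (dy, dx) in enumerate(offsets.tolist())]
    )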
# Get the current observation for predicting trajectories, and all future observations for use in
# the latent consistency loss and TD loss.
@@ -332,37 +342,44 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
for k in observations:
    current_observation[k] = observations[k][0]
    next_observations[k] = observations[k][1:]
-horizon = next_observations["observation.image"].shape[0]
+horizon, batch_size = next_observations["observation.image"].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
-batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
-z_preds[0] = self.model.encode(current_observation)
+z_preds[0] = self.model.encode(current_observation) # TODO(now): Same
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
-z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
+z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
+    z_preds[t], action[t]
+) # TODO(now): same
# Compute Q and V value predictions based on the latent rollout.
-q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
-v_preds = self.model.V(z_preds[:-1])
+q_preds_ensemble = self.model.Qs(
+    z_preds[:-1], action
+) # (ensemble, horizon, batch) # TODO(now): all zeros
+v_preds = self.model.V(z_preds[:-1]) # TODO(now): same
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
# Compute various targets with stopgrad.
with torch.no_grad():
# Latent state consistency targets.
-z_targets = self.model_target.encode(next_observations)
+z_targets = self.model_target.encode(next_observations) # TODO(now): same
# State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
# learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
# uses a learned state value function: V(z). This means the TD targets only depend on in-sample
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
-q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
+q_targets = reward + self.config.discount * self.model.V(
+    self.model.encode(next_observations)
+) # TODO(now): same
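Written out, the + lines above implement the V-based TD target the comment describes: for each step t,

    q_targets[t] = reward[t] + discount * V(encode(next_observations)[t])

so the bootstrap never queries an action from π, only the state value of the next observation.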
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
-v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
+v_targets = self.model_target.Qs(
+    z_preds[:-1].detach(), action, return_min=True
+) # TODO(now): zeros
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
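The decay referred to above scales each timestep's loss by rho**t before the .sum(0)/.mean() reductions. A minimal sketch, assuming a temporal_decay_coeff config field (the field name is an assumption):

# Weight timestep t by rho**t so that far-future prediction errors
# contribute less than near-term ones.
temporal_loss_coeffs = torch.pow(
    self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
).unsqueeze(-1) # (t, 1), broadcasts against the (t, b) per-step losses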
@@ -383,7 +400,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
-.mean()
+.mean() # TODO(now): same
)
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
# rewards.
@@ -397,12 +414,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
* ~batch["action_is_pad"]
)
.sum(0)
-.mean()
+.mean() # TODO(now): same
)
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
q_value_loss = (
(
-F.mse_loss(
+temporal_loss_coeffs
+* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
@@ -415,7 +433,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
-.mean()
+.mean() # TODO(now): same
)
# Compute state value loss as in eqn 3 of FOWM.
diff = v_targets - v_preds
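The diff above is the residual for the expectile regression of eqn 3 (the weighting itself falls outside this hunk). A minimal sketch of that loss, assuming an expectile_weight config field (the field name is an assumption):

# Expectile regression: weight positive residuals (V underestimates the
# target) by tau and negative ones by (1 - tau); tau > 0.5 pushes V
# toward an upper expectile of the Q targets.
tau = self.config.expectile_weight
raw_v_value_loss = torch.where(diff > 0, tau, 1 - tau) * diff.pow(2) # (t, b)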
@@ -434,7 +452,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
* ~batch["action_is_pad"]
)
.sum(0)
-.mean()
+.mean() # TODO(now): same
)
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
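Advantage weighted regression turns the π update into weighted behavioral cloning: each in-sample action's regression loss is scaled by the exponentiated advantage of that action. A rough sketch of the weights (both the advantage estimate and the advantage_scaling field are assumptions, not necessarily what this file does):

# AWR: actions judged better than the value baseline get larger weights.
# Clamp to keep the exponential from blowing up the loss.
with torch.no_grad():
    advantage = q_targets - v_preds # Q(z, a) estimate minus the V(z) baseline
    exp_advantage = torch.exp(advantage * self.config.advantage_scaling).clamp(max=100.0)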
@@ -451,10 +469,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
-# gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
-# the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
-# as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
-# parameter for it (see below where we compute the total loss).
+# gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
+# multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
+# dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
+# the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
+# loss).
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
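For reference, the identity the revised comment relies on: for a unit-variance Gaussian over a d-dimensional action,

    -log p(a | mu) = 0.5 * ||a - mu||^2 + (d / 2) * log(2 * pi)

so after dropping the constant and folding the 0.5 into a configurable coefficient, the summed MSE on the line above is exactly what remains.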
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
# other losses.
@@ -467,7 +486,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
-).mean()
+).mean() # TODO(now): same
loss = (
self.config.consistency_coeff * consistency_loss