backup

Commit 918868162e (parent fb202b5040)
@@ -13,7 +13,6 @@ TODO(alexander-soare): Use batch-first throughout.
 import logging
 from collections import deque
 from copy import deepcopy
-from functools import partial
 from typing import Callable
 
 import einops
@@ -308,6 +307,17 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
 
+        import os
+        from time import sleep
+
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            batch.update(torch.load("/tmp/batch.pth"))
+            os.remove("/tmp/mutex.txt")
+            break
+
         info = {}
 
         # (b, t) -> (t, b)
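The block added in this hunk is a debugging handshake: the training step busy-waits until a sentinel file appears, then overwrites the sampled batch with one saved to disk by another process, presumably so two implementations can be fed identical data (the `TODO(now): same` / `all zeros` annotations in the later hunks point the same way). A minimal sketch of what the producer side of that handshake might look like; the function name is hypothetical and only the consumer side shown above is part of the commit:

```python
# Hypothetical producer-side counterpart (not part of this commit).
# Write the payload first, then create the sentinel file that the while-loop
# in the diff polls for, so the reader never loads a partially written file.
import torch


def publish_batch(batch: dict, batch_path="/tmp/batch.pth", mutex_path="/tmp/mutex.txt") -> None:
    torch.save(batch, batch_path)  # the batch the other process should train on
    with open(mutex_path, "w"):
        pass  # an empty file is enough; only its existence is checked
```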
@@ -315,16 +325,16 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             if batch[key].ndim > 1:
                 batch[key] = batch[key].transpose(1, 0)
 
-        action = batch["action"]  # (t, b)
-        reward = batch["next.reward"]  # (t,)
+        action = batch["action"]  # (t, b, action_dim)
+        reward = batch["next.reward"]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
 
-        # Apply random image augmentations.
-        if self.config.max_random_shift_ratio > 0:
-            observations["observation.image"] = flatten_forward_unflatten(
-                partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
-                observations["observation.image"],
-            )
+        # # Apply random image augmentations.
+        # if self.config.max_random_shift_ratio > 0:
+        #     observations["observation.image"] = flatten_forward_unflatten(
+        #         partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
+        #         observations["observation.image"],
+        #     )
 
         # Get the current observation for predicting trajectories, and all future observations for use in
         # the latent consistency loss and TD loss.
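Besides fixing the shape comments for the time-first layout, this hunk disables the random shift image augmentation by commenting it out. For orientation, a simplified stand-in for what such an augmentation does (pad, then crop at a random offset); this is an illustrative sketch with a made-up name, not the repository's `random_shifts_aug`, which may differ in details such as subpixel interpolation:

```python
import torch
import torch.nn.functional as F


def random_shift(images: torch.Tensor, max_random_shift_ratio: float) -> torch.Tensor:
    """Randomly translate a (B, C, H, W) image batch by up to ratio * H pixels."""
    b, _, h, w = images.shape
    pad = int(round(max_random_shift_ratio * h))
    if pad == 0:
        return images
    # Pad with edge replication, then take a random crop of the original size per sample.
    padded = F.pad(images, [pad, pad, pad, pad], mode="replicate")
    out = torch.empty_like(images)
    for i in range(b):
        dy, dx = torch.randint(0, 2 * pad + 1, (2,)).tolist()
        out[i] = padded[i, :, dy : dy + h, dx : dx + w]
    return out
```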
@@ -332,37 +342,44 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         for k in observations:
             current_observation[k] = observations[k][0]
             next_observations[k] = observations[k][1:]
-        horizon = next_observations["observation.image"].shape[0]
+        horizon, batch_size = next_observations["observation.image"].shape[:2]
 
         # Run latent rollout using the latent dynamics model and policy model.
         # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
         # gives us a next `z`.
-        batch_size = batch["index"].shape[0]
         z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
-        z_preds[0] = self.model.encode(current_observation)
+        z_preds[0] = self.model.encode(current_observation)  # TODO(now): Same
         reward_preds = torch.empty_like(reward, device=device)
         for t in range(horizon):
-            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
+            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
+                z_preds[t], action[t]
+            )  # TODO(now): same
 
         # Compute Q and V value predictions based on the latent rollout.
-        q_preds_ensemble = self.model.Qs(z_preds[:-1], action)  # (ensemble, horizon, batch)
-        v_preds = self.model.V(z_preds[:-1])
+        q_preds_ensemble = self.model.Qs(
+            z_preds[:-1], action
+        )  # (ensemble, horizon, batch)  # TODO(now): all zeros
+        v_preds = self.model.V(z_preds[:-1])  # TODO(now): same
         info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
 
         # Compute various targets with stopgrad.
         with torch.no_grad():
             # Latent state consistency targets.
-            z_targets = self.model_target.encode(next_observations)
+            z_targets = self.model_target.encode(next_observations)  # TODO(now): same
             # State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the
             # learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM
             # uses a learned state value function: V(z). This means the TD targets only depend on in-sample
             # actions (not actions estimated by π).
             # Note: Here we do not use self.model_target, but self.model. This is to follow the original code
             # and the FOWM paper.
-            q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
+            q_targets = reward + self.config.discount * self.model.V(
+                self.model.encode(next_observations)
+            )  # TODO(now): same
             # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
             # are using them to compute loss for V.
-            v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
+            v_targets = self.model_target.Qs(
+                z_preds[:-1].detach(), action, return_min=True
+            )  # TODO(now): zeros
 
         # Compute losses.
         # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
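In the notation of the code above, the two stop-gradient targets computed under `torch.no_grad()` are, per timestep $t$ and with discount $\gamma$,

$$q\text{-target}_t = r_t + \gamma\, V\!\big(e(o_{t+1})\big), \qquad v\text{-target}_t = \min_k \bar{Q}_k(z_t, a_t),$$

where $e$ is the observation encoder, $V$ the state value head (online model, per the comment), and $\bar{Q}_k$ the $k$-th head of the target Q ensemble.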
@@ -383,7 +400,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
         # rewards.
@@ -397,12 +414,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["action_is_pad"]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
         q_value_loss = (
             (
-                F.mse_loss(
+                temporal_loss_coeffs
+                * F.mse_loss(
                     q_preds_ensemble,
                     einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
                     reduction="none",
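The `temporal_loss_coeffs` factor introduced here, together with the `.sum(0).mean()` reductions in the surrounding hunks, implements the exponentially decayed, padding-masked reduction described by the "Exponentially decay the loss weight" comment. A rough sketch of that pattern with hypothetical names and an assumed coefficient definition (the actual coefficient tensor is built elsewhere in the file):

```python
import torch


def decayed_masked_reduction(per_step_loss: torch.Tensor, is_pad: torch.Tensor, decay: float) -> torch.Tensor:
    """per_step_loss: (t, b) float; is_pad: (t, b) bool; decay: scalar in (0, 1]."""
    t = per_step_loss.shape[0]
    coeffs = torch.pow(decay, torch.arange(t, device=per_step_loss.device))  # (t,)
    # Zero out padded steps, down-weight distant timesteps, sum over time, average over the batch.
    return (coeffs.unsqueeze(-1) * per_step_loss * ~is_pad).sum(0).mean()
```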
@@ -415,7 +433,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["observation.state_is_pad"][1:]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
         # Compute state value loss as in eqn 3 of FOWM.
         diff = v_targets - v_preds
@@ -434,7 +452,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
                 * ~batch["action_is_pad"]
             )
             .sum(0)
-            .mean()
+            .mean()  # TODO(now): same
         )
 
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
@@ -451,10 +469,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
         action_preds = self.model.pi(z_preds[:-1])  # (t, b, a)
         # Calculate the MSE between the actions and the action predictions.
         # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
-        # gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
-        # the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
-        # as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
-        # parameter for it (see below where we compute the total loss).
+        # gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
+        # multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
+        # dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
+        # the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
+        # loss).
         mse = F.mse_loss(action_preds, action, reduction="none").sum(-1)  # (t, b)
         # NOTE: The original implementation does not take the sum over the temporal dimension like with the
         # other losses.
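The reworded comment appeals to the standard identity for a unit-variance Gaussian over a $d$-dimensional action; written out,

$$-\log \mathcal{N}(a \mid \mu, I_d) = \tfrac{1}{2}\lVert a - \mu \rVert^2 + \tfrac{d}{2}\log(2\pi),$$

so summing the squared error over the action dimension and applying a tunable coefficient is equivalent to the negative log-likelihood up to that constant offset and the factor of 1/2.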
@@ -467,7 +486,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
             # `action_preds` depends on the first observation and the actions.
             * ~batch["observation.state_is_pad"][0]
             * ~batch["action_is_pad"]
-        ).mean()
+        ).mean()  # TODO(now): same
 
         loss = (
             self.config.consistency_coeff * consistency_loss