backup wip

parent 6885eb1933
commit 69f5077a9c
@@ -6,7 +6,7 @@ import torch
 from torchvision.transforms import v2

 from lerobot.common.datasets.utils import compute_stats
-from lerobot.common.transforms import NormalizeTransform, Prod
+from lerobot.common.transforms import Prod

 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

@@ -77,15 +77,16 @@ def make_dataset(

     transforms = v2.Compose(
         [
-            Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
-            NormalizeTransform(
-                stats,
-                in_keys=[
-                    "observation.state",
-                    "action",
-                ],
-                mode=normalization_mode,
-            ),
+            # TODO(now)
+            Prod(in_keys=clsfunc.image_keys, prod=1 / 1.0),
+            # NormalizeTransform(
+            #     stats,
+            #     in_keys=[
+            #         "observation.state",
+            #         "action",
+            #     ],
+            #     mode=normalization_mode,
+            # ),
         ]
     )

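Note on the hunk above: the original pipeline rescaled images from [0, 255] to [0, 1] (`Prod(prod=1 / 255.0)`) and normalized state/action with dataset statistics; the WIP change disables both (prod becomes 1 / 1.0 and NormalizeTransform is commented out). A minimal standalone sketch of what that normalization amounted to, assuming a `stats` dict with "mean"/"std" entries per key (the exact layout of `stats` is an assumption here):

import torch

def normalize_item(item: dict[str, torch.Tensor], stats: dict[str, dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    # Scale uint8 images into [0, 1], as Prod(in_keys=image_keys, prod=1 / 255.0) does.
    item["observation.image"] = item["observation.image"] / 255.0
    # Standardize low-dimensional keys with dataset statistics ("mean_std"-style normalization).
    for key in ("observation.state", "action"):
        item[key] = (item[key] - stats[key]["mean"]) / (stats[key]["std"] + 1e-8)
    return item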
@@ -22,7 +22,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):

 def make_policy(hydra_cfg: DictConfig):
     if hydra_cfg.policy.name == "tdmpc":
-        from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

         policy = TDMPCPolicy(
             hydra_cfg.policy,
@@ -1,11 +1,13 @@
 import os
 import pickle
 import re
+from typing import Callable

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
+from torch import Tensor
 from torch import distributions as pyd
 from torch.distributions.utils import _standard_normal

@@ -96,7 +98,9 @@ def set_requires_grad(net, value):


 class TruncatedNormal(pyd.Normal):
-    """Utility class implementing the truncated normal distribution."""
+    """Utility class implementing the truncated normal distribution while still passing gradients through.
+    TODO(now): consider simplifying the hell out of this but only once you understand what self.eps is for.
+    """

     default_sample_shape = torch.Size()

@@ -107,6 +111,8 @@ class TruncatedNormal(pyd.Normal):
         self.eps = eps

     def _clamp(self, x):
+        # TODO(now): Hm looks like this is designed to pass gradients through!
+        # TODO(now): Understand what this eps is for.
         clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
         x = x - x.detach() + clamped_x.detach()
         return x
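Note on the hunk above: `x - x.detach() + clamped_x.detach()` is the straight-through trick the first TODO guesses at. The forward value equals the clamped tensor, while the backward pass sees the identity, so gradients flow through the clamp unchanged. A small self-contained check:

import torch

x = torch.tensor([0.5, 2.0, -3.0], requires_grad=True)
clamped = torch.clamp(x, -1.0, 1.0)

# Straight-through clamp: forward value is `clamped`, gradient is that of the identity.
y = x - x.detach() + clamped.detach()
y.sum().backward()

print(y)       # tensor([ 0.5000,  1.0000, -1.0000], grad_fn=...)
print(x.grad)  # tensor([1., 1., 1.]) -- a plain clamp would zero the gradient outside [-1, 1]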
@@ -141,7 +147,12 @@ class Flatten(nn.Module):
         return x.view(x.size(0), -1)


-def enc(cfg):
+def enc(cfg) -> Callable[[dict[str, Tensor] | Tensor], dict[str, Tensor] | Tensor]:
+    """
+    Creates encoders for pixel and/or state modalities.
+
+    TODO(now): Consolidate this into just working with a dict even if there is just one modality.
+    """
     obs_shape = {
         "rgb": (3, cfg.img_size, cfg.img_size),
         "state": (cfg.state_dim,),
@@ -152,6 +163,7 @@ def enc(cfg):
     if cfg.modality in {"pixels", "all"}:
         C = int(3 * cfg.frame_stack)  # noqa: N806
         pixels_enc_layers = [
+            # TODO(now): Leave this to the env / data loader
            NormalizeImg(),
            nn.Conv2d(C, cfg.num_channels, 7, stride=2),
            nn.ReLU(),
@@ -198,7 +210,7 @@ def enc(cfg):
     return Multiplexer(nn.ModuleDict(encoders))


-def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
+def mlp(in_dim: int, mlp_dim: int | tuple[int, int], out_dim: int, act_fn=DEFAULT_ACT_FN):
     """Returns an MLP."""
     if isinstance(mlp_dim, int):
         mlp_dim = [mlp_dim, mlp_dim]
@@ -214,7 +226,10 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):


 def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
-    """Returns a dynamics network."""
+    """Returns a dynamics network.
+
+    TODO(now): this needs a better name. It's also an MLP...
+    """
     return nn.Sequential(
         mlp(in_dim, mlp_dim, out_dim, act_fn),
         nn.LayerNorm(out_dim),
@@ -271,7 +286,7 @@ def aug(cfg):


 class ConvExt(nn.Module):
-    """Auxiliary conv net accommodating high-dim input"""
+    """Helper to deal with arbitrary dimensions (B, *, C, H, W) for the input images."""

     def __init__(self, conv):
         super().__init__()
@@ -279,10 +294,13 @@ class ConvExt(nn.Module):

     def forward(self, x):
         if x.ndim > 4:
+            # x has shape (B, *, C, H, W), so we first flatten (B, *) into the first dim, run the
+            # layers, then unflatten to return the result.
             batch_shape = x.shape[:-3]
             out = self.conv(x.view(-1, *x.shape[-3:]))
             out = out.view(*batch_shape, *out.shape[1:])
         else:
+            # x has shape (B, C, H, W).
             out = self.conv(x)
         return out

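Note on the hunk above: a quick usage sketch of the shape handling that the new comments describe, using `ConvExt` as defined in this file:

import torch
import torch.nn as nn

conv = ConvExt(nn.Conv2d(3, 8, kernel_size=3))

x = torch.randn(2, 5, 3, 32, 32)  # (B, T, C, H, W): extra leading dims beyond (B, C, H, W)
print(conv(x).shape)              # torch.Size([2, 5, 8, 30, 30]) -- leading dims are preserved

y = torch.randn(2, 3, 32, 32)     # plain (B, C, H, W) input passes straight through
print(conv(y).shape)              # torch.Size([2, 8, 30, 30])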
@@ -290,7 +308,7 @@ class ConvExt(nn.Module):
 class Multiplexer(nn.Module):
     """Model multiplexer"""

-    def __init__(self, choices):
+    def __init__(self, choices: nn.ModuleDict):
         super().__init__()
         self.choices = choices

@@ -542,6 +560,7 @@ def normalize_returns(dataset, scaling=1000):
 def get_reward_normalizer(cfg, dataset):
     """
     Get a reward normalizer for the dataset
+    TODO(now): Leave this to the dataloader/env
     """
     if cfg.task.startswith("xarm"):
         return lambda x: x
@@ -570,6 +589,8 @@ def linear_schedule(schdl, step):
         return (1.0 - mix) * init + mix * final
     match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
     if match:
+        # TODO(now): Looks like the original tdmpc code uses this with
+        # `horizon_schedule: linear(1, ${horizon}, 25000)`
         init, final, duration = (float(g) for g in match.groups())
         mix = np.clip(step / duration, 0.0, 1.0)
         return (1.0 - mix) * init + mix * final
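Note on the hunk above: for the `horizon_schedule: linear(1, ${horizon}, 25000)` example in the new comment, the schedule string parses and interpolates as follows (same regex and formula as the function body):

import re

import numpy as np

def linear_schedule_demo(schdl: str, step: int) -> float:
    # Parse "linear(init, final, duration)" and interpolate linearly over `duration` steps.
    init, final, duration = (float(g) for g in re.match(r"linear\((.+),(.+),(.+)\)", schdl).groups())
    mix = np.clip(step / duration, 0.0, 1.0)
    return (1.0 - mix) * init + mix * final

print(linear_schedule_demo("linear(1, 5, 25000)", 0))      # 1.0
print(linear_schedule_demo("linear(1, 5, 25000)", 12500))  # 3.0
print(linear_schedule_demo("linear(1, 5, 25000)", 25000))  # 5.0, and it stays at 5.0 afterwards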
@@ -8,6 +8,7 @@ import einops
 import numpy as np
 import torch
 import torch.nn as nn
+from torch import Tensor

 import lerobot.common.policies.tdmpc.helper as h
 from lerobot.common.policies.utils import populate_queues
@@ -45,12 +46,13 @@ class TOLD(nn.Module):
         if hasattr(self, "_V"):
             h.set_requires_grad(self._V, enable)

-    def encode(self, obs):
+    def encode(self, obs: dict[str, Tensor]) -> Tensor:
         """Encodes an observation into its latent representation."""
         out = self._encoder(obs)
         if isinstance(obs, dict):
             # fusion
-            out = torch.stack([v for k, v in out.items()]).mean(dim=0)
+            # TODO(now): careful about the order!
+            out = torch.stack(list(out.values())).mean(dim=0)
         return out

     def next(self, z, a):
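Note on the "careful about the order" TODO: since Python 3.7, dicts preserve insertion order, so `list(out.values())` follows the order in which the encoders were registered in the `nn.ModuleDict`. For the mean fusion used here the order does not actually matter (averaging is permutation-invariant); it would only start to matter if the fusion became order-sensitive, e.g. a concatenation, in which case sorting keys keeps it reproducible. A sketch (not what the commit does):

import torch

out = {"rgb": torch.randn(4, 50), "state": torch.randn(4, 50)}

# Mean fusion is permutation-invariant, so dict order is harmless here.
fused_mean = torch.stack(list(out.values())).mean(dim=0)      # (4, 50)

# A concat-style fusion would depend on order; sorting keys makes it deterministic by name.
fused_cat = torch.cat([out[k] for k in sorted(out)], dim=-1)  # (4, 100)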
@@ -63,8 +65,14 @@ class TOLD(nn.Module):
         x = torch.cat([z, a], dim=-1)
         return self._dynamics(x)

-    def pi(self, z, std=0):
-        """Samples an action from the learned policy (pi)."""
+    def pi(self, z: Tensor, std: float = 0.0) -> float:
+        """Samples an action from the learned policy (pi).
+
+        TODO(now): Explain why added noise (something to do with FOWM improvements)
+
+        Returns:
+            action
+        """
         mu = torch.tanh(self._pi(z))
         if std > 0:
             std = torch.ones_like(mu) * std
@@ -101,6 +109,7 @@ class TDMPCPolicy(nn.Module):
         self.n_obs_steps = n_obs_steps
         self.n_action_steps = n_action_steps
         self.device = get_safe_torch_device(device)
+        # TODO(now): How should this be set when loading a model for eval?
         self.std = h.linear_schedule(cfg.std_schedule, 0)
         self.model = TOLD(cfg)
         self.model.to(self.device)
@@ -140,7 +149,7 @@ class TDMPCPolicy(nn.Module):
         }

     @torch.no_grad()
-    def select_action(self, batch, step):
+    def select_action(self, batch: dict[str, Tensor], step):
         assert "observation.image" in batch
         assert "observation.state" in batch
         assert len(batch) == 2
@@ -154,6 +163,8 @@ class TDMPCPolicy(nn.Module):
         if len(self._queues["action"]) == 0:
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}

+            # TODO(now): this shouldn't be necessary. downstream code should handle this the same as
+            # n_obs_steps > 1
             if self.n_obs_steps == 1:
                 # hack to remove the time dimension
                 for key in batch:
@@ -163,11 +174,12 @@ class TDMPCPolicy(nn.Module):
            actions = []
            batch_size = batch["observation.image"].shape[0]
            for i in range(batch_size):
+                # TODO(now): this looks like selecting the ith but keeping dims. Do it in a cleaner way.
+                # TODO(now): this looks like maybe it doesn't care about what the modality choice is?
                obs = {
                    "rgb": batch["observation.image"][[i]],
                    "state": batch["observation.state"][[i]],
                }
-                # Note: unsqueeze needed because `act` still uses non-batch logic.
                action = self.act(obs, t0=t0, step=self.step)
                actions.append(action)
            action = torch.stack(actions)
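Note on the "selecting the ith but keeping dims" TODO: indexing with a one-element list keeps the batch dimension, which is what `batch["observation.image"][[i]]` relies on:

import torch

x = torch.randn(8, 3, 84, 84)
print(x[0].shape)    # torch.Size([3, 84, 84])    -- integer indexing drops the batch dim
print(x[[0]].shape)  # torch.Size([1, 3, 84, 84]) -- list indexing keeps it (same as x[0:1])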
@@ -180,33 +192,36 @@ class TDMPCPolicy(nn.Module):
         return action

     @torch.no_grad()
-    def act(self, obs, t0=False, step=None):
+    def act(self, obs: dict[str, Tensor], t0=False, step=None):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
+        # TODO(now): Understand this detach.
         obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
         # TODO(now): This is for compatibility with official weights. Remove.
         # obs['rgb'] = obs['rgb'] * 255
         # obs_ = torch.load('/tmp/obs.pth')
         # out_ = torch.load('/tmp/out.pth')
-        # breakpoint()
         z = self.model.encode(obs)
-        # breakpoint()
         if self.cfg.mpc:
+            assert step is not None
             a = self.plan(z, t0=t0, step=step)
         else:
             a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
         return a

     @torch.no_grad()
-    def estimate_value(self, z, actions, horizon):
+    def estimate_value(self, z: Tensor, actions: Tensor, horizon: int):
         """Estimate value of a trajectory starting at latent state z and executing given actions."""
         G, discount = 0, 1
         for t in range(horizon):
+            # TODO(now): What is the uncertainty cost.
             if self.cfg.uncertainty_cost > 0:
                 G -= (
                     discount
                     * self.cfg.uncertainty_cost
                     * self.model.Q(z, actions[t], return_type="all").std(dim=0)
                 )
+            # TODO(now): Isn't this repeating the computation of the trajectory?
+            # TODO(now): `next` and `next_dynamics` are poorly named.
             z, reward = self.model.next(z, actions[t])
             G += discount * reward
             discount *= self.cfg.discount
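Note on the hunk above: `estimate_value` accumulates a discounted return along the imagined trajectory, minus an optional penalty proportional to the disagreement (std) across the Q ensemble. A stripped-down sketch of the same accumulation, where `dynamics` and `q_ensemble` are stand-ins for `self.model.next` and `self.model.Q(..., return_type="all")`:

def estimate_value_sketch(z, actions, horizon, dynamics, q_ensemble, gamma=0.99, uncertainty_cost=1.0):
    # dynamics(z, a) -> (next_z, reward); q_ensemble(z, a) -> tensor of shape (num_q, batch, 1).
    G, discount = 0.0, 1.0
    for t in range(horizon):
        if uncertainty_cost > 0:
            # Penalize states/actions where the Q heads disagree (a proxy for epistemic uncertainty).
            G -= discount * uncertainty_cost * q_ensemble(z, actions[t]).std(dim=0)
        z, reward = dynamics(z, actions[t])
        G += discount * reward
        discount *= gamma
    return G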
@@ -217,22 +232,27 @@ class TDMPCPolicy(nn.Module):
         return G

     @torch.no_grad()
-    def plan(self, z, step=None, t0=True):
-        """
-        Plan next action using TD-MPC inference.
-        z: latent state.
-        step: current time step. determines e.g. planning horizon.
-        t0: whether current step is the first step of an episode.
-        """
-        # during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
-        assert step is not None
-
+    def plan(self, z: Tensor, step: int, t0: bool = True) -> Tensor:
+        """Plan next action using TD-MPC inference.
+
+        during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
+
+        Args:
+            z: latent state.
+            step: current time step. determines e.g. planning horizon. TODO(now): So this differs from training step
+            t0: whether current step is the first step of an episode.
+        """
         # Seed steps
+        # TODO(now): What are seed steps?
         if step < self.cfg.seed_steps and self.model.training:
             return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)

         # Sample policy trajectories
+        # TODO(now): Looks like horizon_schedule is not used in fowm. There, this ends up evaluating to
+        # self.cfg.horizon. On the other hand, the TDMPC code base does use `horizon_schedule: linear(1, ${horizon}, 25000)`
+        # Basically this causes the horizon to ramp up linearly from 1 to self.cfg.horizon. Find out why.
         horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
+        # TODO(now): What is this?
         num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
         if num_pi_trajs > 0:
             pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
@@ -246,9 +266,10 @@ class TDMPCPolicy(nn.Module):
         mean = torch.zeros(horizon, self.action_dim, device=self.device)
         std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
         if not t0 and hasattr(self, "_prev_mean"):
+            # TODO(now): I think this is a "warm start"
             mean[:-1] = self._prev_mean[1:]

-        # Iterate CEM
+        # Iterate CEM (cross-entropy method)
         for _ in range(self.cfg.iterations):
             actions = torch.clamp(
                 mean.unsqueeze(1)
@@ -258,6 +279,8 @@ class TDMPCPolicy(nn.Module):
                 1,
             )
             if num_pi_trajs > 0:
+                # TODO(now): So this is mixing a batch of model-generated action trajectories with
+                # randomly generated ones?
                 actions = torch.cat([actions, pi_actions], dim=1)

             # Compute elite actions
@@ -267,6 +290,7 @@ class TDMPCPolicy(nn.Module):

             # Update parameters
             max_value = elite_value.max(0)[0]
+            # TODO(now): Perhaps this is better viewed as torch.exp(-(self.cfg.temperature * (max_value - elite_value)))?
             score = torch.exp(self.cfg.temperature * (elite_value - max_value))
             score /= score.sum(0)
             _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
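Note on the new TODO above: subtracting `max_value` before exponentiating is the standard numerically stable softmax shift; it leaves the normalized weights unchanged but keeps the exponentials bounded, and the two forms in the TODO are algebraically identical. A small check:

import torch

elite_value = torch.tensor([500.0, 502.0, 510.0])
temperature = 0.5

# exp(temperature * value) alone would overflow for large values; shifting by the max does not
# change the normalized scores.
max_value = elite_value.max()
score = torch.exp(temperature * (elite_value - max_value))
score = score / score.sum()
print(score)  # ≈ tensor([0.0066, 0.0179, 0.9756]) -- higher-value elites get more weight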
@@ -277,6 +301,8 @@ class TDMPCPolicy(nn.Module):
                 )
                 / (score.sum(0) + 1e-9)
             )
+            # TODO(now): self.std seems to be modified by update. But update has no reason to be called when
+            # we are loading a pretrained model.
             _std = _std.clamp_(self.std, self.cfg.max_std)
             mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std

@@ -289,7 +315,7 @@ class TDMPCPolicy(nn.Module):
         score = score.squeeze(1).cpu().numpy()
         actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
         self._prev_mean = mean
-        mean, std = actions[0], _std[0]
+        mean, std = actions[0], _std[0]  # TODO(now): not std[0]?
         a = mean
         if self.model.training:
             a += std * torch.randn(self.action_dim, device=std.device)
@@ -126,6 +126,22 @@ def eval_policy(
     max_steps = env.envs[0]._max_episode_steps
     progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
     while not done.all():
+        # Receive observation:
+        import os
+        from time import sleep
+
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            observation = {}
+            observation["rgb"] = np.load("/tmp/rgb.npy")
+            observation["state"] = np.load("/tmp/state.npy")
+            observation["pixels"] = observation["rgb"].transpose(1, 2, 0)[None]
+            observation["agent_pos"] = observation["state"][None]
+            done = np.load("/tmp/done.npy")
+            break
+
         # format from env keys to lerobot keys
         observation = preprocess_observation(observation)
         if return_episode_data:
@@ -142,9 +158,21 @@ def eval_policy(
         with torch.inference_mode():
             action = policy.select_action(observation, step=step)

+        # Send action:
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            torch.save(action[0], "/tmp/action.pth")
+            os.remove("/tmp/mutex.txt")
+            break
+
+        if done:
+            policy.reset()
+            continue
+
         # apply inverse transform to unnormalize the action
         action = postprocess_action(action, transform)
-        action = np.array([[0, 0, 0, 0]], dtype=np.float32)

         # apply the next action
         observation, reward, terminated, truncated, info = env.step(action)
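Note on the two WIP hunks in `eval_policy`: together they turn the rollout loop into one half of a crude file-based handshake. A peer process is expected to write `/tmp/rgb.npy`, `/tmp/state.npy` and `/tmp/done.npy` and then create `/tmp/mutex.txt`; this loop reads the observation, saves the policy's action to `/tmp/action.pth`, and removes the mutex file to signal completion. A hypothetical sketch of that peer side (the file names come from the diff; the function and everything else about the peer are assumptions):

import os
from time import sleep

import numpy as np
import torch

def request_action(rgb: np.ndarray, state: np.ndarray, done: bool) -> torch.Tensor:
    # Publish the observation, then create the mutex file to signal "observation ready".
    np.save("/tmp/rgb.npy", rgb)
    np.save("/tmp/state.npy", state)
    np.save("/tmp/done.npy", np.array(done))
    open("/tmp/mutex.txt", "w").close()

    # eval_policy consumes the observation, writes the action, and deletes the mutex file.
    while os.path.exists("/tmp/mutex.txt"):
        sleep(0.01)
    return torch.load("/tmp/action.pth")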
@@ -332,8 +360,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     logging.info("Making transforms.")
     # TODO(alexander-soare): Completely decouple datasets from evaluation.
     transform = make_dataset(cfg, stats_path=stats_path).transform
-    # TODO(now)
-    transform = None

     logging.info("Making environment.")
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
[File diff suppressed because it is too large]
@@ -9,7 +9,7 @@ from lerobot.common.datasets.pusht import PushtDataset
 from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
 from tests.utils import require_env
