backup wip
parent 6885eb1933
commit 69f5077a9c
@@ -6,7 +6,7 @@ import torch
from torchvision.transforms import v2

from lerobot.common.datasets.utils import compute_stats
-from lerobot.common.transforms import NormalizeTransform, Prod
+from lerobot.common.transforms import Prod

DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

@@ -77,15 +77,16 @@ def make_dataset(

    transforms = v2.Compose(
        [
-            Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
-            NormalizeTransform(
-                stats,
-                in_keys=[
-                    "observation.state",
-                    "action",
-                ],
-                mode=normalization_mode,
-            ),
+            # TODO(now)
+            Prod(in_keys=clsfunc.image_keys, prod=1 / 1.0),
+            # NormalizeTransform(
+            #     stats,
+            #     in_keys=[
+            #         "observation.state",
+            #         "action",
+            #     ],
+            #     mode=normalization_mode,
+            # ),
        ]
    )

@@ -22,7 +22,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):

def make_policy(hydra_cfg: DictConfig):
    if hydra_cfg.policy.name == "tdmpc":
-        from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

        policy = TDMPCPolicy(
            hydra_cfg.policy,
@@ -1,11 +1,13 @@
import os
import pickle
import re
+from typing import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
+from torch import Tensor
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal

@@ -96,7 +98,9 @@ def set_requires_grad(net, value):


class TruncatedNormal(pyd.Normal):
-    """Utility class implementing the truncated normal distribution."""
+    """Utility class implementing the truncated normal distribution while still passing gradients through.
+    TODO(now): consider simplifying the hell out of this but only once you understand what self.eps is for.
+    """

    default_sample_shape = torch.Size()

@@ -107,6 +111,8 @@ class TruncatedNormal(pyd.Normal):
        self.eps = eps

    def _clamp(self, x):
+        # TODO(now): Hm looks like this is designed to pass gradients through!
+        # TODO(now): Understand what this eps is for.
        clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        x = x - x.detach() + clamped_x.detach()
        return x
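Aside: the `x - x.detach() + clamped_x.detach()` line above is a straight-through estimator; the forward pass returns the clamped value while gradients flow as if no clamp were applied. A minimal standalone sketch (illustrative values, not part of this commit):

import torch

x = torch.tensor([0.5, 1.7, -2.0], requires_grad=True)
low, high, eps = -1.0, 1.0, 1e-6
clamped = torch.clamp(x, low + eps, high - eps)
y = x - x.detach() + clamped.detach()  # forward: clamped values; backward: identity w.r.t. x
y.sum().backward()
print(y)       # approx [0.5, 1.0, -1.0]
print(x.grad)  # [1., 1., 1.] -- gradients are not zeroed outside the clamp range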
@@ -141,7 +147,12 @@ class Flatten(nn.Module):
        return x.view(x.size(0), -1)


-def enc(cfg):
+def enc(cfg) -> Callable[[dict[str, Tensor] | Tensor], dict[str, Tensor] | Tensor]:
+    """
+    Creates encoders for pixel and/or state modalities.
+    TODO(now): Consolidate this into just working with a dict even if there is just one modality.
+    """
+
    obs_shape = {
        "rgb": (3, cfg.img_size, cfg.img_size),
        "state": (cfg.state_dim,),
@@ -152,6 +163,7 @@ def enc(cfg):
    if cfg.modality in {"pixels", "all"}:
        C = int(3 * cfg.frame_stack)  # noqa: N806
        pixels_enc_layers = [
+            # TODO(now): Leave this to the env / data loader
            NormalizeImg(),
            nn.Conv2d(C, cfg.num_channels, 7, stride=2),
            nn.ReLU(),
@@ -198,7 +210,7 @@ def enc(cfg):
    return Multiplexer(nn.ModuleDict(encoders))


-def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
+def mlp(in_dim: int, mlp_dim: int | tuple[int, int], out_dim: int, act_fn=DEFAULT_ACT_FN):
    """Returns an MLP."""
    if isinstance(mlp_dim, int):
        mlp_dim = [mlp_dim, mlp_dim]
@@ -214,7 +226,10 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):


def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
-    """Returns a dynamics network."""
+    """Returns a dynamics network.
+
+    TODO(now): this needs a better name. It's also an MLP...
+    """
    return nn.Sequential(
        mlp(in_dim, mlp_dim, out_dim, act_fn),
        nn.LayerNorm(out_dim),
@@ -271,7 +286,7 @@ def aug(cfg):


class ConvExt(nn.Module):
-    """Auxiliary conv net accommodating high-dim input"""
+    """Helper to deal with arbitrary dimensions (B, *, C, H, W) for the input images."""

    def __init__(self, conv):
        super().__init__()
@@ -279,10 +294,13 @@ class ConvExt(nn.Module):

    def forward(self, x):
        if x.ndim > 4:
+            # x has shape (B, *, C, H, W) so we first flatten (B, *) into the first dim, run the
+            # layers, then unflatten to return the result.
            batch_shape = x.shape[:-3]
            out = self.conv(x.view(-1, *x.shape[-3:]))
            out = out.view(*batch_shape, *out.shape[1:])
        else:
+            # x has shape (B, C, H, W).
            out = self.conv(x)
        return out

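Aside: a minimal usage sketch of the (B, *, C, H, W) flattening pattern that `ConvExt.forward` implements; the conv stack and shapes here are illustrative only:

import torch
import torch.nn as nn

conv = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
x = torch.randn(2, 5, 3, 32, 32)               # (B, T, C, H, W)
out = conv(x.view(-1, *x.shape[-3:]))          # run on (B*T, C, H, W)
out = out.view(*x.shape[:-3], *out.shape[1:])  # back to (B, T, 8, 32, 32)
print(out.shape)                               # torch.Size([2, 5, 8, 32, 32])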
@@ -290,7 +308,7 @@ class ConvExt(nn.Module):
class Multiplexer(nn.Module):
    """Model multiplexer"""

-    def __init__(self, choices):
+    def __init__(self, choices: nn.ModuleDict):
        super().__init__()
        self.choices = choices

@@ -542,6 +560,7 @@ def normalize_returns(dataset, scaling=1000):
def get_reward_normalizer(cfg, dataset):
    """
    Get a reward normalizer for the dataset
+    TODO(now): Leave this to the dataloader/env
    """
    if cfg.task.startswith("xarm"):
        return lambda x: x
@@ -570,6 +589,8 @@ def linear_schedule(schdl, step):
        return (1.0 - mix) * init + mix * final
    match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
    if match:
+        # TODO(now): Looks like the original tdmpc code uses this with
+        # `horizon_schedule: linear(1, ${horizon}, 25000)`
        init, final, duration = (float(g) for g in match.groups())
        mix = np.clip(step / duration, 0.0, 1.0)
        return (1.0 - mix) * init + mix * final
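Aside: a minimal sketch of how a `linear(init, final, duration)` schedule string like the one referenced in the TODO evaluates; the example numbers are made up:

import re
import numpy as np

def linear_value(schdl: str, step: int) -> float:
    # Parse "linear(init,final,duration)" and interpolate linearly, clamping past the duration.
    init, final, duration = (float(g) for g in re.match(r"linear\((.+),(.+),(.+)\)", schdl).groups())
    mix = np.clip(step / duration, 0.0, 1.0)
    return (1.0 - mix) * init + mix * final

print(linear_value("linear(0.5,0.05,25000)", 0))      # 0.5
print(linear_value("linear(0.5,0.05,25000)", 12500))  # 0.275
print(linear_value("linear(0.5,0.05,25000)", 25000))  # 0.05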
@@ -8,6 +8,7 @@ import einops
import numpy as np
import torch
import torch.nn as nn
+from torch import Tensor

import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.utils import populate_queues
@@ -45,12 +46,13 @@ class TOLD(nn.Module):
        if hasattr(self, "_V"):
            h.set_requires_grad(self._V, enable)

-    def encode(self, obs):
+    def encode(self, obs: dict[str, Tensor]) -> Tensor:
        """Encodes an observation into its latent representation."""
        out = self._encoder(obs)
        if isinstance(obs, dict):
            # fusion
-            out = torch.stack([v for k, v in out.items()]).mean(dim=0)
+            # TODO(now): careful about the order!
+            out = torch.stack(list(out.values())).mean(dim=0)
        return out

    def next(self, z, a):
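Aside: regarding the "careful about the order!" TODO, one way to make the fusion deterministic is to stack the dict values in a fixed key order. A hypothetical sketch (tensors and sizes are made up):

import torch

out = {"rgb": torch.randn(2, 50), "state": torch.randn(2, 50)}
fused = torch.stack([out[k] for k in sorted(out)]).mean(dim=0)  # key order fixed by sorting
print(fused.shape)  # torch.Size([2, 50])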
@@ -63,8 +65,14 @@ class TOLD(nn.Module):
        x = torch.cat([z, a], dim=-1)
        return self._dynamics(x)

-    def pi(self, z, std=0):
-        """Samples an action from the learned policy (pi)."""
+    def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
+        """Samples an action from the learned policy (pi).
+
+        TODO(now): Explain why added noise (something to do with FOWM improvements)
+
+        Returns:
+            action
+        """
        mu = torch.tanh(self._pi(z))
        if std > 0:
            std = torch.ones_like(mu) * std
@@ -101,6 +109,7 @@ class TDMPCPolicy(nn.Module):
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.device = get_safe_torch_device(device)
+        # TODO(now): How should this be set when loading a model for eval?
        self.std = h.linear_schedule(cfg.std_schedule, 0)
        self.model = TOLD(cfg)
        self.model.to(self.device)
@@ -140,7 +149,7 @@ class TDMPCPolicy(nn.Module):
        }

    @torch.no_grad()
-    def select_action(self, batch, step):
+    def select_action(self, batch: dict[str, Tensor], step):
        assert "observation.image" in batch
        assert "observation.state" in batch
        assert len(batch) == 2
@@ -154,6 +163,8 @@ class TDMPCPolicy(nn.Module):
        if len(self._queues["action"]) == 0:
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}

+            # TODO(now): this shouldn't be necessary. downstream code should handle this the same as
+            # n_obs_steps > 1
            if self.n_obs_steps == 1:
                # hack to remove the time dimension
                for key in batch:
@@ -163,11 +174,12 @@ class TDMPCPolicy(nn.Module):
            actions = []
            batch_size = batch["observation.image"].shape[0]
            for i in range(batch_size):
+                # TODO(now): this looks like selecting the ith but keeping dims. Do it in a cleaner way.
+                # TODO(now): this looks like maybe it doesn't care about what the modality choice is?
                obs = {
                    "rgb": batch["observation.image"][[i]],
                    "state": batch["observation.state"][[i]],
                }
-                # Note: unsqueeze needed because `act` still uses non-batch logic.
                action = self.act(obs, t0=t0, step=self.step)
                actions.append(action)
            action = torch.stack(actions)
@@ -180,33 +192,36 @@ class TDMPCPolicy(nn.Module):
        return action

    @torch.no_grad()
-    def act(self, obs, t0=False, step=None):
+    def act(self, obs: dict[str, Tensor], t0=False, step=None):
        """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
        # TODO(now): Understand this detach.
        obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
        # TODO(now): This is for compatibility with official weights. Remove.
        # obs['rgb'] = obs['rgb'] * 255
        # obs_ = torch.load('/tmp/obs.pth')
        # out_ = torch.load('/tmp/out.pth')
        # breakpoint()
        z = self.model.encode(obs)
        # breakpoint()
        if self.cfg.mpc:
            assert step is not None
            a = self.plan(z, t0=t0, step=step)
        else:
            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
        return a

    @torch.no_grad()
-    def estimate_value(self, z, actions, horizon):
+    def estimate_value(self, z: Tensor, actions: Tensor, horizon: int):
        """Estimate value of a trajectory starting at latent state z and executing given actions."""
        G, discount = 0, 1
        for t in range(horizon):
            # TODO(now): What is the uncertainty cost.
            if self.cfg.uncertainty_cost > 0:
                G -= (
                    discount
                    * self.cfg.uncertainty_cost
                    * self.model.Q(z, actions[t], return_type="all").std(dim=0)
                )
            # TODO(now): Isn't this repeating the computation of the trajectory?
            # TODO(now): `next` and `next_dynamics` are poorly named.
            z, reward = self.model.next(z, actions[t])
            G += discount * reward
            discount *= self.cfg.discount
@@ -217,22 +232,27 @@ class TDMPCPolicy(nn.Module):
        return G

    @torch.no_grad()
-    def plan(self, z, step=None, t0=True):
-        """
-        Plan next action using TD-MPC inference.
+    def plan(self, z: Tensor, step: int, t0: bool = True) -> Tensor:
+        """Plan next action using TD-MPC inference.

-        during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
        Args:
            z: latent state.
-            step: current time step. determines e.g. planning horizon.
+            step: current time step. determines e.g. planning horizon. TODO(now): So this differs from training step
            t0: whether current step is the first step of an episode.
        """
+        # during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.

        assert step is not None
        # Seed steps
        # TODO(now): What are seed steps?
        if step < self.cfg.seed_steps and self.model.training:
            return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)

        # Sample policy trajectories
        # TODO(now): Looks like horizon_schedule is not used in fowm. There, this ends up evaluating to
        # self.cfg.horizon. On the other hand, the TDMPC code base does use `horizon_schedule: linear(1, ${horizon}, 25000)`
        # Basically this causes the horizon to ramp up linearly from 1 to self.cfg.horizon. Find out why.
        horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
        # TODO(now): What is this?
        num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
        if num_pi_trajs > 0:
            pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
@@ -246,9 +266,10 @@ class TDMPCPolicy(nn.Module):
        mean = torch.zeros(horizon, self.action_dim, device=self.device)
        std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
        if not t0 and hasattr(self, "_prev_mean"):
+            # TODO(now): I think this is a "warm start"
            mean[:-1] = self._prev_mean[1:]

-        # Iterate CEM
+        # Iterate CEM (cross-entropy method)
        for _ in range(self.cfg.iterations):
            actions = torch.clamp(
                mean.unsqueeze(1)
@@ -258,6 +279,8 @@ class TDMPCPolicy(nn.Module):
                1,
            )
            if num_pi_trajs > 0:
+                # TODO(now): So this combines a batch of model-generated action trajectories with randomly
+                # generated ones?
                actions = torch.cat([actions, pi_actions], dim=1)

            # Compute elite actions
@@ -267,6 +290,7 @@ class TDMPCPolicy(nn.Module):

            # Update parameters
            max_value = elite_value.max(0)[0]
+            # TODO(now): Perhaps this is better viewed as torch.exp(-(self.cfg.temperature * (max_value - elite_value)))?
            score = torch.exp(self.cfg.temperature * (elite_value - max_value))
            score /= score.sum(0)
            _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
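Aside: the max-subtracted exponential above, once normalized, is just a numerically stable softmax over temperature-scaled elite values. A quick check with made-up numbers:

import torch

temperature = 0.5
elite_value = torch.tensor([1.0, 2.0, 4.0])
score = torch.exp(temperature * (elite_value - elite_value.max()))  # subtracting the max avoids overflow
score = score / score.sum()
print(torch.allclose(score, torch.softmax(temperature * elite_value, dim=0)))  # True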
@@ -277,6 +301,8 @@ class TDMPCPolicy(nn.Module):
                )
                / (score.sum(0) + 1e-9)
            )
+            # TODO(now): self.std seems to be modified by update. But update has no reason to be called when
+            # we are loading a pretrained model.
            _std = _std.clamp_(self.std, self.cfg.max_std)
            mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std

@@ -289,7 +315,7 @@ class TDMPCPolicy(nn.Module):
        score = score.squeeze(1).cpu().numpy()
        actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
        self._prev_mean = mean
-        mean, std = actions[0], _std[0]
+        mean, std = actions[0], _std[0]  # TODO(now): not std[0]?
        a = mean
        if self.model.training:
            a += std * torch.randn(self.action_dim, device=std.device)
@@ -126,6 +126,22 @@ def eval_policy(
    max_steps = env.envs[0]._max_episode_steps
    progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
    while not done.all():
+        # Receive observation:
+        import os
+        from time import sleep
+
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            observation = {}
+            observation["rgb"] = np.load("/tmp/rgb.npy")
+            observation["state"] = np.load("/tmp/state.npy")
+            observation["pixels"] = observation["rgb"].transpose(1, 2, 0)[None]
+            observation["agent_pos"] = observation["state"][None]
+            done = np.load("/tmp/done.npy")
+            break
+
        # format from env keys to lerobot keys
        observation = preprocess_observation(observation)
        if return_episode_data:
@@ -142,9 +158,21 @@ def eval_policy(
        with torch.inference_mode():
            action = policy.select_action(observation, step=step)

+        # Send action:
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            torch.save(action[0], "/tmp/action.pth")
+            os.remove("/tmp/mutex.txt")
+            break
+
+        if done:
+            policy.reset()
+            continue
+
        # apply inverse transform to unnormalize the action
-        action = postprocess_action(action, transform)
+        action = np.array([[0, 0, 0, 0]], dtype=np.float32)

        # apply the next action
        observation, reward, terminated, truncated, info = env.step(action)
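Aside: the file-based handshake above implies a counterpart process that writes the observation files, creates the mutex, and reads back the action. A hypothetical sketch of that side (paths and filenames are taken from the diff; the function itself is illustrative, not part of this commit):

import os
from pathlib import Path
from time import sleep

import numpy as np
import torch

def exchange(rgb: np.ndarray, state: np.ndarray, done: bool) -> torch.Tensor:
    # Write the observation files that the eval loop polls for.
    np.save("/tmp/rgb.npy", rgb)
    np.save("/tmp/state.npy", state)
    np.save("/tmp/done.npy", np.array(done))
    Path("/tmp/mutex.txt").touch()            # signal: observation ready
    while os.path.exists("/tmp/mutex.txt"):   # the eval loop deletes the mutex after saving the action
        sleep(0.01)
    return torch.load("/tmp/action.pth")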
@@ -332,8 +360,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
    logging.info("Making transforms.")
    # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    transform = make_dataset(cfg, stats_path=stats_path).transform
    # TODO(now)
    transform = None

    logging.info("Making environment.")
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
File diff suppressed because it is too large
@@ -9,7 +9,7 @@ from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from tests.utils import require_env
