backup wip

Alexander Soare 2024-04-23 11:09:50 +01:00
parent 6885eb1933
commit 69f5077a9c
7 changed files with 575 additions and 513 deletions

View File

@@ -6,7 +6,7 @@ import torch
from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
-from lerobot.common.transforms import NormalizeTransform, Prod
+from lerobot.common.transforms import Prod

DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -77,15 +77,16 @@ def make_dataset(
    transforms = v2.Compose(
        [
-            Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
-            NormalizeTransform(
-                stats,
-                in_keys=[
-                    "observation.state",
-                    "action",
-                ],
-                mode=normalization_mode,
-            ),
+            # TODO(now)
+            Prod(in_keys=clsfunc.image_keys, prod=1 / 1.0),
+            # NormalizeTransform(
+            #     stats,
+            #     in_keys=[
+            #         "observation.state",
+            #         "action",
+            #     ],
+            #     mode=normalization_mode,
+            # ),
        ]
    )
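Aside: the `Prod` transform kept here is not shown in this diff. A rough, hypothetical stand-in for what it presumably does (multiply the tensors at `in_keys` by a scalar, so `prod=1 / 255.0` maps uint8 images into [0, 1] while `prod=1 / 1.0` is a no-op):

import torch

class ProdSketch:
    """Hypothetical stand-in for lerobot.common.transforms.Prod (assumed behaviour)."""

    def __init__(self, in_keys: list[str], prod: float):
        self.in_keys = in_keys
        self.prod = prod

    def __call__(self, item: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Scale only the listed keys; leave everything else untouched.
        for key in self.in_keys:
            item[key] = item[key] * self.prod
        return item

sample = {"observation.image": torch.full((3, 96, 96), 255.0)}
print(ProdSketch(["observation.image"], 1 / 255.0)(sample)["observation.image"].max())  # tensor(1.)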

View File

@@ -22,7 +22,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
def make_policy(hydra_cfg: DictConfig):
    if hydra_cfg.policy.name == "tdmpc":
-        from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

        policy = TDMPCPolicy(
            hydra_cfg.policy,

View File

@@ -1,11 +1,13 @@
import os
import pickle
import re
+from typing import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
+from torch import Tensor
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
@@ -96,7 +98,9 @@ def set_requires_grad(net, value):
class TruncatedNormal(pyd.Normal):
-    """Utility class implementing the truncated normal distribution."""
+    """Utility class implementing the truncated normal distribution while still passing gradients through.
+
+    TODO(now): consider simplifying the hell out of this but only once you understand what self.eps is for.
+    """

    default_sample_shape = torch.Size()
@@ -107,6 +111,8 @@ class TruncatedNormal(pyd.Normal):
        self.eps = eps

    def _clamp(self, x):
+        # TODO(now): Hm, looks like this is designed to pass gradients through!
+        # TODO(now): Understand what this eps is for.
        clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        x = x - x.detach() + clamped_x.detach()
        return x
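On the two TODOs above: `x - x.detach() + clamped_x.detach()` is a straight-through clamp. The forward value equals the clamped value, but the gradient with respect to x is 1, as if no clamp had been applied. A minimal standalone check with made-up values:

import torch

x = torch.tensor([2.0], requires_grad=True)
clamped_x = torch.clamp(x, -1.0, 1.0)
y = x - x.detach() + clamped_x.detach()  # forward value is the clamped 1.0
y.backward()
print(y.item(), x.grad.item())  # 1.0 1.0 -> the gradient passes through unchanged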
@ -141,7 +147,12 @@ class Flatten(nn.Module):
return x.view(x.size(0), -1) return x.view(x.size(0), -1)
def enc(cfg): def enc(cfg) -> Callable[[dict[str, Tensor] | Tensor], dict[str, Tensor] | Tensor]:
"""
Creates encoders for pixel and/or state modalities.
TODO(now): Consolidate this into just working with a dict even if there is just one modality.
"""
obs_shape = { obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size), "rgb": (3, cfg.img_size, cfg.img_size),
"state": (cfg.state_dim,), "state": (cfg.state_dim,),
@@ -152,6 +163,7 @@ def enc(cfg):
    if cfg.modality in {"pixels", "all"}:
        C = int(3 * cfg.frame_stack)  # noqa: N806
        pixels_enc_layers = [
+            # TODO(now): Leave this to the env / data loader
            NormalizeImg(),
            nn.Conv2d(C, cfg.num_channels, 7, stride=2),
            nn.ReLU(),
@ -198,7 +210,7 @@ def enc(cfg):
return Multiplexer(nn.ModuleDict(encoders)) return Multiplexer(nn.ModuleDict(encoders))
def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN): def mlp(in_dim: int, mlp_dim: int | tuple[int, int], out_dim: int, act_fn=DEFAULT_ACT_FN):
"""Returns an MLP.""" """Returns an MLP."""
if isinstance(mlp_dim, int): if isinstance(mlp_dim, int):
mlp_dim = [mlp_dim, mlp_dim] mlp_dim = [mlp_dim, mlp_dim]
@@ -214,7 +226,10 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
-    """Returns a dynamics network."""
+    """Returns a dynamics network.
+
+    TODO(now): this needs a better name. It's also an MLP...
+    """
    return nn.Sequential(
        mlp(in_dim, mlp_dim, out_dim, act_fn),
        nn.LayerNorm(out_dim),
@@ -271,7 +286,7 @@ def aug(cfg):
class ConvExt(nn.Module):
-    """Auxiliary conv net accommodating high-dim input"""
+    """Helper to deal with arbitrary dimensions (B, *, C, H, W) for the input images."""

    def __init__(self, conv):
        super().__init__()
@@ -279,10 +294,13 @@ class ConvExt(nn.Module):
    def forward(self, x):
        if x.ndim > 4:
+            # x has shape (B, *, C, H, W), so we first flatten (B, *) into the first dim, run the
+            # layers, then unflatten to return the result.
            batch_shape = x.shape[:-3]
            out = self.conv(x.view(-1, *x.shape[-3:]))
            out = out.view(*batch_shape, *out.shape[1:])
        else:
+            # x has shape (B, C, H, W).
            out = self.conv(x)
        return out
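The flatten/unflatten dance described in the new comments can be checked in isolation with made-up shapes:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)
x = torch.randn(2, 5, 3, 32, 32)           # (B, T, C, H, W)
batch_shape = x.shape[:-3]                  # (2, 5)
out = conv(x.view(-1, *x.shape[-3:]))       # the conv runs on the flattened (B*T, C, H, W)
out = out.view(*batch_shape, *out.shape[1:])
print(out.shape)                            # torch.Size([2, 5, 8, 30, 30])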
@@ -290,7 +308,7 @@ class ConvExt(nn.Module):
class Multiplexer(nn.Module):
    """Model multiplexer"""

-    def __init__(self, choices):
+    def __init__(self, choices: nn.ModuleDict):
        super().__init__()
        self.choices = choices
@@ -542,6 +560,7 @@ def normalize_returns(dataset, scaling=1000):
def get_reward_normalizer(cfg, dataset):
    """
    Get a reward normalizer for the dataset
+    TODO(now): Leave this to the dataloader/env
    """
    if cfg.task.startswith("xarm"):
        return lambda x: x
@@ -570,6 +589,8 @@ def linear_schedule(schdl, step):
        return (1.0 - mix) * init + mix * final
    match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
    if match:
+        # TODO(now): Looks like the original tdmpc code uses this with
+        # `horizon_schedule: linear(1, ${horizon}, 25000)`
        init, final, duration = (float(g) for g in match.groups())
        mix = np.clip(step / duration, 0.0, 1.0)
        return (1.0 - mix) * init + mix * final
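To make the new TODO concrete, here is a trimmed re-implementation of just the regex branch above (not the full helper): with a hypothetical horizon of 5, `linear(1, 5, 25000)` ramps from 1 to 5 over 25k steps and then stays flat.

import re

import numpy as np

def linear(schdl: str, step: int) -> float:
    # Same parsing and interpolation as the branch above, outside the helper module.
    init, final, duration = (float(g) for g in re.match(r"linear\((.+),(.+),(.+)\)", schdl).groups())
    mix = float(np.clip(step / duration, 0.0, 1.0))
    return (1.0 - mix) * init + mix * final

print([linear("linear(1, 5, 25000)", s) for s in (0, 12500, 25000, 100000)])  # [1.0, 3.0, 5.0, 5.0]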

View File

@ -8,6 +8,7 @@ import einops
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor
import lerobot.common.policies.tdmpc.helper as h import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.utils import populate_queues from lerobot.common.policies.utils import populate_queues
@@ -45,12 +46,13 @@ class TOLD(nn.Module):
        if hasattr(self, "_V"):
            h.set_requires_grad(self._V, enable)

-    def encode(self, obs):
+    def encode(self, obs: dict[str, Tensor]) -> Tensor:
        """Encodes an observation into its latent representation."""
        out = self._encoder(obs)
        if isinstance(obs, dict):
            # fusion
-            out = torch.stack([v for k, v in out.items()]).mean(dim=0)
+            # TODO(now): careful about the order!
+            out = torch.stack(list(out.values())).mean(dim=0)
        return out

    def next(self, z, a):
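On the "careful about the order" TODO: `out.values()` follows dict insertion order (guaranteed since Python 3.7), so the stack-and-mean fusion is deterministic as long as the encoder always emits modalities in the same order. A toy illustration with made-up embeddings:

import torch

out = {"rgb": torch.ones(4, 50), "state": torch.zeros(4, 50)}  # per-modality embeddings
z = torch.stack(list(out.values())).mean(dim=0)                 # fused latent
print(z.shape, z[0, 0].item())                                  # torch.Size([4, 50]) 0.5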
@@ -63,8 +65,14 @@ class TOLD(nn.Module):
        x = torch.cat([z, a], dim=-1)
        return self._dynamics(x)

-    def pi(self, z, std=0):
-        """Samples an action from the learned policy (pi)."""
+    def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
+        """Samples an action from the learned policy (pi).
+
+        TODO(now): Explain why noise is added (something to do with the FOWM improvements).
+
+        Returns:
+            action
+        """
        mu = torch.tanh(self._pi(z))
        if std > 0:
            std = torch.ones_like(mu) * std
@@ -101,6 +109,7 @@ class TDMPCPolicy(nn.Module):
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.device = get_safe_torch_device(device)
+        # TODO(now): How should this be set when loading a model for eval?
        self.std = h.linear_schedule(cfg.std_schedule, 0)
        self.model = TOLD(cfg)
        self.model.to(self.device)
@@ -140,7 +149,7 @@ class TDMPCPolicy(nn.Module):
        }

    @torch.no_grad()
-    def select_action(self, batch, step):
+    def select_action(self, batch: dict[str, Tensor], step):
        assert "observation.image" in batch
        assert "observation.state" in batch
        assert len(batch) == 2
@@ -154,6 +163,8 @@ class TDMPCPolicy(nn.Module):
        if len(self._queues["action"]) == 0:
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}

+            # TODO(now): this shouldn't be necessary. Downstream code should handle this the same as
+            # n_obs_steps > 1.
            if self.n_obs_steps == 1:
                # hack to remove the time dimension
                for key in batch:
@@ -163,11 +174,12 @@ class TDMPCPolicy(nn.Module):
            actions = []
            batch_size = batch["observation.image"].shape[0]
            for i in range(batch_size):
+                # TODO(now): this looks like selecting the ith but keeping dims. Do it in a cleaner way.
+                # TODO(now): this looks like maybe it doesn't care about what the modality choice is?
                obs = {
                    "rgb": batch["observation.image"][[i]],
                    "state": batch["observation.state"][[i]],
                }
-                # Note: unsqueeze needed because `act` still uses non-batch logic.
                action = self.act(obs, t0=t0, step=self.step)
                actions.append(action)
            action = torch.stack(actions)
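On the "selecting the ith but keeping dims" TODO: indexing with a list keeps a batch dimension of size 1, whereas plain integer indexing drops it; `x[i : i + 1]` is the more conventional spelling of the same thing.

import torch

x = torch.randn(8, 3, 96, 96)
print(x[[2]].shape, x[2:3].shape, x[2].shape)
# torch.Size([1, 3, 96, 96]) torch.Size([1, 3, 96, 96]) torch.Size([3, 96, 96])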
@@ -180,33 +192,36 @@ class TDMPCPolicy(nn.Module):
        return action

    @torch.no_grad()
-    def act(self, obs, t0=False, step=None):
+    def act(self, obs: dict[str, Tensor], t0=False, step=None):
        """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
+        # TODO(now): Understand this detach.
        obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
        # TODO(now): This is for compatibility with official weights. Remove.
        # obs['rgb'] = obs['rgb'] * 255
        # obs_ = torch.load('/tmp/obs.pth')
        # out_ = torch.load('/tmp/out.pth')
-        # breakpoint()
        z = self.model.encode(obs)
-        # breakpoint()
        if self.cfg.mpc:
+            assert step is not None
            a = self.plan(z, t0=t0, step=step)
        else:
            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
        return a

    @torch.no_grad()
-    def estimate_value(self, z, actions, horizon):
+    def estimate_value(self, z: Tensor, actions: Tensor, horizon: int):
        """Estimate value of a trajectory starting at latent state z and executing given actions."""
        G, discount = 0, 1
        for t in range(horizon):
+            # TODO(now): What is the uncertainty cost?
            if self.cfg.uncertainty_cost > 0:
                G -= (
                    discount
                    * self.cfg.uncertainty_cost
                    * self.model.Q(z, actions[t], return_type="all").std(dim=0)
                )
+            # TODO(now): Isn't this repeating the computation of the trajectory?
+            # TODO(now): `next` and `next_dynamics` are poorly named.
            z, reward = self.model.next(z, actions[t])
            G += discount * reward
            discount *= self.cfg.discount
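On the "What is the uncertainty cost" TODO: `Q(z, a, return_type="all")` returns one value per Q-head, and the standard deviation across heads acts as an ensemble-disagreement penalty (presumably the FOWM-style pessimism term), scaled by `uncertainty_cost` and the discount. A toy numeric sketch with made-up Q values:

import torch

q_all = torch.tensor([[1.0], [1.2], [0.8]])    # hypothetical: 3 Q-heads, one (z, a) pair
uncertainty_cost = 0.5
penalty = uncertainty_cost * q_all.std(dim=0)  # more disagreement -> bigger penalty
print(penalty)  # tensor([0.1000])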
@@ -217,22 +232,27 @@ class TDMPCPolicy(nn.Module):
        return G

    @torch.no_grad()
-    def plan(self, z, step=None, t0=True):
-        """
-        Plan next action using TD-MPC inference.
-        z: latent state.
-        step: current time step. determines e.g. planning horizon.
-        t0: whether current step is the first step of an episode.
-        """
-        # during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
-        assert step is not None
+    def plan(self, z: Tensor, step: int, t0: bool = True) -> Tensor:
+        """Plan next action using TD-MPC inference.
+
+        during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
+
+        Args:
+            z: latent state.
+            step: current time step. determines e.g. planning horizon. TODO(now): So this differs from training step
+            t0: whether current step is the first step of an episode.
+        """
        # Seed steps
+        # TODO(now): What are seed steps?
        if step < self.cfg.seed_steps and self.model.training:
            return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)

        # Sample policy trajectories
+        # TODO(now): Looks like horizon_schedule is not used in fowm. There, this ends up evaluating to
+        # self.cfg.horizon. On the other hand, the TDMPC code base does use `horizon_schedule: linear(1, ${horizon}, 25000)`
+        # Basically this causes the horizon to ramp up linearly from 1 to self.cfg.horizon. Find out why.
        horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
+        # TODO(now): What is this?
        num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
        if num_pi_trajs > 0:
            pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
@@ -246,9 +266,10 @@ class TDMPCPolicy(nn.Module):
        mean = torch.zeros(horizon, self.action_dim, device=self.device)
        std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
        if not t0 and hasattr(self, "_prev_mean"):
+            # TODO(now): I think this is a "warm start"
            mean[:-1] = self._prev_mean[1:]

-        # Iterate CEM
+        # Iterate CEM (cross-entropy method)
        for _ in range(self.cfg.iterations):
            actions = torch.clamp(
                mean.unsqueeze(1)
@@ -258,6 +279,8 @@ class TDMPCPolicy(nn.Module):
                1,
            )
            if num_pi_trajs > 0:
+                # TODO(now): So this is having a batch of model generated action trajectories and randomly
+                # generated ones?
                actions = torch.cat([actions, pi_actions], dim=1)

            # Compute elite actions
@@ -267,6 +290,7 @@ class TDMPCPolicy(nn.Module):
            # Update parameters
            max_value = elite_value.max(0)[0]
+            # TODO(now): Perhaps this is better viewed as torch.exp(-(self.cfg.temperature * (max_value - elite_value)))?
            score = torch.exp(self.cfg.temperature * (elite_value - max_value))
            score /= score.sum(0)
            _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
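On the TODO about rewriting the score: subtracting `max_value` before the exponential is the usual numerical-stability shift, and after normalisation the weights are identical to the unshifted version (whenever the latter does not overflow). Quick check with made-up elite values:

import torch

temperature = 0.5
elite_value = torch.tensor([[1.0], [2.0], [4.0]])
max_value = elite_value.max(0)[0]
shifted = torch.exp(temperature * (elite_value - max_value))
plain = torch.exp(temperature * elite_value)
print(torch.allclose(shifted / shifted.sum(0), plain / plain.sum(0)))  # True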
@@ -277,6 +301,8 @@ class TDMPCPolicy(nn.Module):
                )
                / (score.sum(0) + 1e-9)
            )
+            # TODO(now): self.std seems to be modified by update. But update has no reason to be called when
+            # we are loading a pretrained model.
            _std = _std.clamp_(self.std, self.cfg.max_std)
            mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
@@ -289,7 +315,7 @@ class TDMPCPolicy(nn.Module):
        score = score.squeeze(1).cpu().numpy()
        actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
        self._prev_mean = mean
-        mean, std = actions[0], _std[0]
+        mean, std = actions[0], _std[0]  # TODO(now): not std[0]?
        a = mean
        if self.model.training:
            a += std * torch.randn(self.action_dim, device=std.device)

View File

@@ -126,6 +126,22 @@ def eval_policy(
    max_steps = env.envs[0]._max_episode_steps
    progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
    while not done.all():
+        # Receive observation:
+        import os
+        from time import sleep
+
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            observation = {}
+            observation["rgb"] = np.load("/tmp/rgb.npy")
+            observation["state"] = np.load("/tmp/state.npy")
+            observation["pixels"] = observation["rgb"].transpose(1, 2, 0)[None]
+            observation["agent_pos"] = observation["state"][None]
+            done = np.load("/tmp/done.npy")
+            break
+
        # format from env keys to lerobot keys
        observation = preprocess_observation(observation)
        if return_episode_data:
@@ -142,9 +158,21 @@ def eval_policy(
        with torch.inference_mode():
            action = policy.select_action(observation, step=step)

+        # Send action:
+        while True:
+            if not os.path.exists("/tmp/mutex.txt"):
+                sleep(0.01)
+                continue
+            torch.save(action[0], "/tmp/action.pth")
+            os.remove("/tmp/mutex.txt")
+            break
+        if done:
+            policy.reset()
+            continue
+
        # apply inverse transform to unnormalize the action
        action = postprocess_action(action, transform)
+        action = np.array([[0, 0, 0, 0]], dtype=np.float32)

        # apply the next action
        observation, reward, terminated, truncated, info = env.step(action)
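The blocks added above implement a crude file-based handshake with an external process through /tmp (clearly temporary debugging code). For context, the counterpart process would have to look roughly like the sketch below; this is inferred from this diff alone and is hypothetical, not code from the repository.

import os
from pathlib import Path
from time import sleep

import numpy as np
import torch

def exchange_step(rgb: np.ndarray, state: np.ndarray, done: bool) -> torch.Tensor:
    """Write one observation, signal with the mutex file, then wait for the action."""
    np.save("/tmp/rgb.npy", rgb)
    np.save("/tmp/state.npy", state)
    np.save("/tmp/done.npy", np.array(done))
    Path("/tmp/mutex.txt").touch()           # signal: observation is ready
    while os.path.exists("/tmp/mutex.txt"):  # the eval side deletes it after saving the action
        sleep(0.01)
    return torch.load("/tmp/action.pth")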
@@ -332,8 +360,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
    logging.info("Making transforms.")
    # TODO(alexander-soare): Completely decouple datasets from evaluation.
    transform = make_dataset(cfg, stats_path=stats_path).transform
-    # TODO(now)
-    transform = None

    logging.info("Making environment.")
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

poetry.lock (generated, 930 changed lines)

File diff suppressed because it is too large.

View File

@@ -9,7 +9,7 @@ from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from tests.utils import require_env