backup wip

Alexander Soare 2024-04-23 11:09:50 +01:00
parent 6885eb1933
commit 69f5077a9c
7 changed files with 575 additions and 513 deletions

View File

@ -6,7 +6,7 @@ import torch
from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
from lerobot.common.transforms import Prod
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@ -77,15 +77,16 @@ def make_dataset(
transforms = v2.Compose(
[
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
# TODO(now)
Prod(in_keys=clsfunc.image_keys, prod=1 / 1.0),
# NormalizeTransform(
# stats,
# in_keys=[
# "observation.state",
# "action",
# ],
# mode=normalization_mode,
# ),
]
)
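For context on what the commented-out normalization did, here is a minimal sketch (not the lerobot implementation; the key names and the stats layout are assumptions) of the two transform steps composed here: a scalar Prod-style rescaling of the image keys and a mean/std normalization of the state and action keys.

import torch
from torch import Tensor

def scale_keys(item: dict[str, Tensor], keys: list[str], factor: float) -> dict[str, Tensor]:
    # e.g. factor=1/255.0 maps uint8 images into [0, 1]; factor=1.0 (as in this WIP) is a no-op.
    return {k: (v * factor if k in keys else v) for k, v in item.items()}

def normalize_keys(item: dict[str, Tensor], keys: list[str], stats: dict[str, dict[str, Tensor]]) -> dict[str, Tensor]:
    # Assumes stats[key] holds precomputed "mean" and "std" tensors (as from compute_stats).
    out = dict(item)
    for k in keys:
        out[k] = (out[k] - stats[k]["mean"]) / (stats[k]["std"] + 1e-8)
    return out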

View File

@ -22,7 +22,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
def make_policy(hydra_cfg: DictConfig):
if hydra_cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
policy = TDMPCPolicy(
hydra_cfg.policy,

View File

@ -1,11 +1,13 @@
import os
import pickle
import re
from typing import Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
@ -96,7 +98,9 @@ def set_requires_grad(net, value):
class TruncatedNormal(pyd.Normal):
"""Utility class implementing the truncated normal distribution."""
"""Utility class implementing the truncated normal distribution while still passing gradients through.
TODO(now): consider simplifying the hell out of this but only once you understand what self.eps is for.
"""
default_sample_shape = torch.Size()
@ -107,6 +111,8 @@ class TruncatedNormal(pyd.Normal):
self.eps = eps
def _clamp(self, x):
# TODO(now): Hm looks like this is designed to pass gradients through!
# TODO(now): Understand what this eps is for.
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
x = x - x.detach() + clamped_x.detach()
return x
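The line `x = x - x.detach() + clamped_x.detach()` is a straight-through trick: the forward pass returns the clamped value, while the backward pass sees the identity, so gradients are not zeroed where the clamp saturates. A quick standalone check (not part of the diff):

import torch

x = torch.tensor([2.0], requires_grad=True)
clamped = torch.clamp(x, -1.0, 1.0)
straight_through = x - x.detach() + clamped.detach()

print(straight_through)   # tensor([1.]): the forward value is the clamped one
straight_through.backward()
print(x.grad)             # tensor([1.]): the gradient flows as if there were no clamp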
@ -141,7 +147,12 @@ class Flatten(nn.Module):
return x.view(x.size(0), -1)
def enc(cfg):
def enc(cfg) -> Callable[[dict[str, Tensor] | Tensor], dict[str, Tensor] | Tensor]:
"""
Creates encoders for pixel and/or state modalities.
TODO(now): Consolidate this into just working with a dict even if there is just one modality.
"""
obs_shape = {
"rgb": (3, cfg.img_size, cfg.img_size),
"state": (cfg.state_dim,),
@ -152,6 +163,7 @@ def enc(cfg):
if cfg.modality in {"pixels", "all"}:
C = int(3 * cfg.frame_stack) # noqa: N806
pixels_enc_layers = [
# TODO(now): Leave this to the env / data loader
NormalizeImg(),
nn.Conv2d(C, cfg.num_channels, 7, stride=2),
nn.ReLU(),
@ -198,7 +210,7 @@ def enc(cfg):
return Multiplexer(nn.ModuleDict(encoders))
def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
def mlp(in_dim: int, mlp_dim: int | tuple[int, int], out_dim: int, act_fn=DEFAULT_ACT_FN):
"""Returns an MLP."""
if isinstance(mlp_dim, int):
mlp_dim = [mlp_dim, mlp_dim]
@ -214,7 +226,10 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
def dynamics(in_dim, mlp_dim, out_dim, act_fn=DEFAULT_ACT_FN):
"""Returns a dynamics network."""
"""Returns a dynamics network.
TODO(now): this needs a better name. It's also an MLP...
"""
return nn.Sequential(
mlp(in_dim, mlp_dim, out_dim, act_fn),
nn.LayerNorm(out_dim),
@ -271,7 +286,7 @@ def aug(cfg):
class ConvExt(nn.Module):
"""Auxiliary conv net accommodating high-dim input"""
"""Helper to deal with arbitrary dimensions (B, *, C, H, W) for the input images."""
def __init__(self, conv):
super().__init__()
@ -279,10 +294,13 @@ class ConvExt(nn.Module):
def forward(self, x):
if x.ndim > 4:
# x has shape (B, *, C, H, W), so we first flatten (B, *) into the first dim, run the
# layers, then unflatten to return the result.
batch_shape = x.shape[:-3]
out = self.conv(x.view(-1, *x.shape[-3:]))
out = out.view(*batch_shape, *out.shape[1:])
else:
# x has shape (B, C, H, W).
out = self.conv(x)
return out
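A standalone shape check of the flatten/unflatten pattern above for a (B, T, C, H, W) input; the channel, height and width numbers are arbitrary:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 5, 3, 32, 32)               # (B, T, C, H, W)
batch_shape = x.shape[:-3]                     # (2, 5)
out = conv(x.view(-1, *x.shape[-3:]))          # conv runs on (B*T, C, H, W) = (10, 3, 32, 32)
out = out.view(*batch_shape, *out.shape[1:])   # back to (2, 5, 8, 32, 32)
print(out.shape)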
@ -290,7 +308,7 @@ class ConvExt(nn.Module):
class Multiplexer(nn.Module):
"""Model multiplexer"""
def __init__(self, choices):
def __init__(self, choices: nn.ModuleDict):
super().__init__()
self.choices = choices
@ -542,6 +560,7 @@ def normalize_returns(dataset, scaling=1000):
def get_reward_normalizer(cfg, dataset):
"""
Get a reward normalizer for the dataset
TODO(now): Leave this to the dataloader/env
"""
if cfg.task.startswith("xarm"):
return lambda x: x
@ -570,6 +589,8 @@ def linear_schedule(schdl, step):
return (1.0 - mix) * init + mix * final
match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
if match:
# TODO(now): Looks like the original tdmpc code uses this with
# `horizon_schedule: linear(1, ${horizon}, 25000)`
init, final, duration = (float(g) for g in match.groups())
mix = np.clip(step / duration, 0.0, 1.0)
return (1.0 - mix) * init + mix * final
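To make the schedule string concrete: the regex branch above parses strings like the TDMPC `horizon_schedule: linear(1, ${horizon}, 25000)` mentioned in the TODO. A stripped-down sketch of just that branch, with 5 standing in for ${horizon}:

import re
import numpy as np

def linear_schedule(schdl: str, step: int) -> float:
    init, final, duration = (float(g) for g in re.match(r"linear\((.+),(.+),(.+)\)", schdl).groups())
    mix = np.clip(step / duration, 0.0, 1.0)
    return (1.0 - mix) * init + mix * final

for step in (0, 12_500, 25_000, 50_000):
    print(step, linear_schedule("linear(1, 5, 25000)", step))
# 0 -> 1.0, 12500 -> 3.0, 25000 -> 5.0, then held at 5.0 once the duration has elapsed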

View File

@ -8,6 +8,7 @@ import einops
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.utils import populate_queues
@ -45,12 +46,13 @@ class TOLD(nn.Module):
if hasattr(self, "_V"):
h.set_requires_grad(self._V, enable)
def encode(self, obs):
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
out = self._encoder(obs)
if isinstance(obs, dict):
# fusion
out = torch.stack([v for k, v in out.items()]).mean(dim=0)
# TODO(now): careful about the order!
out = torch.stack(list(out.values())).mean(dim=0)
return out
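A minimal illustration of the fusion step above, with made-up batch and latent sizes: the per-modality latents are stacked along a new dim and averaged into a single latent of the same size.

import torch

latents = {
    "rgb": torch.randn(4, 50),     # (batch, latent_dim)
    "state": torch.randn(4, 50),
}
fused = torch.stack(list(latents.values())).mean(dim=0)   # (4, 50): average over the modalities
print(fused.shape)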
def next(self, z, a):
@ -63,8 +65,14 @@ class TOLD(nn.Module):
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z, std=0):
"""Samples an action from the learned policy (pi)."""
def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
"""Samples an action from the learned policy (pi).
TODO(now): Explain why noise is added (something to do with the FOWM improvements)
Returns:
action
"""
mu = torch.tanh(self._pi(z))
if std > 0:
std = torch.ones_like(mu) * std
@ -101,6 +109,7 @@ class TDMPCPolicy(nn.Module):
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
# TODO(now): How should this be set when loading a model for eval?
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg)
self.model.to(self.device)
@ -140,7 +149,7 @@ class TDMPCPolicy(nn.Module):
}
@torch.no_grad()
def select_action(self, batch, step):
def select_action(self, batch: dict[str, Tensor], step):
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
@ -154,6 +163,8 @@ class TDMPCPolicy(nn.Module):
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
# TODO(now): this shouldn't be necessary. Downstream code should handle this the same as
# n_obs_steps > 1
if self.n_obs_steps == 1:
# hack to remove the time dimension
for key in batch:
@ -163,11 +174,12 @@ class TDMPCPolicy(nn.Module):
actions = []
batch_size = batch["observation.image"].shape[0]
for i in range(batch_size):
# TODO(now): this looks like selecting the i-th element while keeping dims. Do it in a cleaner way.
# TODO(now): it also looks like maybe it doesn't care about what the modality choice is?
obs = {
"rgb": batch["observation.image"][[i]],
"state": batch["observation.state"][[i]],
}
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step)
actions.append(action)
action = torch.stack(actions)
@ -180,33 +192,36 @@ class TDMPCPolicy(nn.Module):
return action
@torch.no_grad()
def act(self, obs, t0=False, step=None):
def act(self, obs: dict[str, Tensor], t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
# TODO(now): Understand this detach.
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
# TODO(now): This is for compatibility with official weights. Remove.
# obs['rgb'] = obs['rgb'] * 255
# obs_ = torch.load('/tmp/obs.pth')
# out_ = torch.load('/tmp/out.pth')
# breakpoint()
z = self.model.encode(obs)
# breakpoint()
if self.cfg.mpc:
assert step is not None
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
def estimate_value(self, z: Tensor, actions: Tensor, horizon: int):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
for t in range(horizon):
# TODO(now): What is the uncertainty cost?
if self.cfg.uncertainty_cost > 0:
G -= (
discount
* self.cfg.uncertainty_cost
* self.model.Q(z, actions[t], return_type="all").std(dim=0)
)
# TODO(now): Isn't this repeating the computation of the trajectory?
# TODO(now): `next` and `next_dynamics` are poorly named.
z, reward = self.model.next(z, actions[t])
G += discount * reward
discount *= self.cfg.discount
@ -217,22 +232,27 @@ class TDMPCPolicy(nn.Module):
return G
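Numerically, estimate_value accumulates a discounted return G = sum_t discount^t * reward_t along the imagined trajectory (plus the optional uncertainty penalty, omitted here). A tiny worked example with made-up rewards:

rewards = [1.0, 0.5, 0.25]     # imagined per-step rewards
discount_factor = 0.9

G, discount = 0.0, 1.0
for reward in rewards:
    G += discount * reward
    discount *= discount_factor
print(G)   # 1.0 + 0.9 * 0.5 + 0.81 * 0.25 = 1.6525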
@torch.no_grad()
def plan(self, z, step=None, t0=True):
"""
Plan next action using TD-MPC inference.
z: latent state.
step: current time step. determines e.g. planning horizon.
t0: whether current step is the first step of an episode.
"""
# during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
def plan(self, z: Tensor, step: int, t0: bool = True) -> Tensor:
"""Plan next action using TD-MPC inference.
assert step is not None
during eval: eval_mode: uniform sampling and action noise is disabled during evaluation.
Args:
z: latent state.
step: current time step. determines e.g. planning horizon. TODO(now): So this differs from training step
t0: whether current step is the first step of an episode.
"""
# Seed steps
# TODO(now): What are seed steps?
if step < self.cfg.seed_steps and self.model.training:
return torch.empty(self.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)
# Sample policy trajectories
# TODO(now): Looks like horizon_schedule is not used in fowm. There, this ends up evaluating to
# self.cfg.horizon. On the other hand, the TDMPC code base does use `horizon_schedule: linear(1, ${horizon}, 25000)`
# Basically this causes the horizon to ramp up linearly from 1 to self.cfg.horizon. Find out why.
horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
# TODO(now): What is this?
num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
if num_pi_trajs > 0:
pi_actions = torch.empty(horizon, num_pi_trajs, self.action_dim, device=self.device)
@ -246,9 +266,10 @@ class TDMPCPolicy(nn.Module):
mean = torch.zeros(horizon, self.action_dim, device=self.device)
std = self.cfg.max_std * torch.ones(horizon, self.action_dim, device=self.device)
if not t0 and hasattr(self, "_prev_mean"):
# TODO(now): I think this is a "warm start"
mean[:-1] = self._prev_mean[1:]
# Iterate CEM
# Iterate CEM (cross-entropy method)
for _ in range(self.cfg.iterations):
actions = torch.clamp(
mean.unsqueeze(1)
@ -258,6 +279,8 @@ class TDMPCPolicy(nn.Module):
1,
)
if num_pi_trajs > 0:
# TODO(now): So this combines a batch of model-generated action trajectories with randomly
# sampled ones?
actions = torch.cat([actions, pi_actions], dim=1)
# Compute elite actions
@ -267,6 +290,7 @@ class TDMPCPolicy(nn.Module):
# Update parameters
max_value = elite_value.max(0)[0]
# TODO(now): Perhaps this is better viewed as torch.exp(-(self.cfg.temperature * (max_value - elite_value)))?
score = torch.exp(self.cfg.temperature * (elite_value - max_value))
score /= score.sum(0)
_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
@ -277,6 +301,8 @@ class TDMPCPolicy(nn.Module):
)
/ (score.sum(0) + 1e-9)
)
# TODO(now): self.std seems to be modified by update. But update has no reason to be called when
# we are loading a pretrained model.
_std = _std.clamp_(self.std, self.cfg.max_std)
mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
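A small numeric sketch of the elite weighting above (made-up values, horizon 1, three elites, 2-D actions). Subtracting the max before exponentiating, as the TODO suggests, is the usual numerical-stability trick and leaves the normalized scores unchanged:

import torch

temperature = 0.5
elite_value = torch.tensor([10.0, 12.0, 11.0])             # value of each elite trajectory
elite_actions = torch.tensor([[[0.0, 0.0],                 # (horizon, num_elites, action_dim)
                               [1.0, 1.0],
                               [0.5, 0.5]]])

score = torch.exp(temperature * (elite_value - elite_value.max()))
score = score / score.sum()
print(score)                                               # highest weight on the highest-value elite

weighted_mean = (score.view(1, -1, 1) * elite_actions).sum(dim=1)
print(weighted_mean)                                       # (horizon, action_dim) weighted action mean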
@ -289,7 +315,7 @@ class TDMPCPolicy(nn.Module):
score = score.squeeze(1).cpu().numpy()
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
self._prev_mean = mean
mean, std = actions[0], _std[0]
mean, std = actions[0], _std[0] # TODO(now): not std[0]?
a = mean
if self.model.training:
a += std * torch.randn(self.action_dim, device=std.device)

View File

@ -126,6 +126,22 @@ def eval_policy(
max_steps = env.envs[0]._max_episode_steps
progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
while not done.all():
# Receive observation:
import os
from time import sleep
while True:
if not os.path.exists("/tmp/mutex.txt"):
sleep(0.01)
continue
observation = {}
observation["rgb"] = np.load("/tmp/rgb.npy")
observation["state"] = np.load("/tmp/state.npy")
observation["pixels"] = observation["rgb"].transpose(1, 2, 0)[None]
observation["agent_pos"] = observation["state"][None]
done = np.load("/tmp/done.npy")
break
# format from env keys to lerobot keys
observation = preprocess_observation(observation)
if return_episode_data:
@ -142,9 +158,21 @@ def eval_policy(
with torch.inference_mode():
action = policy.select_action(observation, step=step)
# Send action:
while True:
if not os.path.exists("/tmp/mutex.txt"):
sleep(0.01)
continue
torch.save(action[0], "/tmp/action.pth")
os.remove("/tmp/mutex.txt")
break
if done:
policy.reset()
continue
# apply inverse transform to unnormalize the action
action = postprocess_action(action, transform)
action = np.array([[0, 0, 0, 0]], dtype=np.float32)
# apply the next action
observation, reward, terminated, truncated, info = env.step(action)
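For the record, the WIP block above talks to an external process through files in /tmp: the eval loop waits for /tmp/mutex.txt to appear, reads rgb.npy / state.npy / done.npy as the observation, then writes action.pth and deletes the mutex. A hypothetical sketch of the producer side of that handshake (not part of this commit; the paths come from the diff, the function name is made up), purely to make the polling protocol explicit:

import os
import time

import numpy as np
import torch

def send_observation_and_wait_for_action(rgb: np.ndarray, state: np.ndarray, done: bool) -> torch.Tensor:
    np.save("/tmp/rgb.npy", rgb)
    np.save("/tmp/state.npy", state)
    np.save("/tmp/done.npy", np.array(done))
    open("/tmp/mutex.txt", "w").close()         # signal the eval loop that fresh data is ready
    while os.path.exists("/tmp/mutex.txt"):     # the eval loop removes the mutex after saving the action
        time.sleep(0.01)
    return torch.load("/tmp/action.pth")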
@ -332,8 +360,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
transform = make_dataset(cfg, stats_path=stats_path).transform
# TODO(now)
transform = None
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

poetry.lock (generated, 930 changed lines): file diff suppressed because it is too large.

View File

@ -9,7 +9,7 @@ from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from tests.utils import require_env