Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)

Remi Cadene 2024-03-02 15:53:29 +00:00
parent b5a2f460ea
commit 1ae6205269
9 changed files with 183 additions and 67 deletions

View File

@@ -1,3 +1,4 @@
+import os
 from pathlib import Path

 import torch
@@ -6,7 +7,7 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler
 from lerobot.common.datasets.pusht import PushtExperienceReplay
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay

-DATA_PATH = Path("data/")
+DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))

 # TODO(rcadene): implement
@@ -64,7 +65,7 @@ def make_offline_buffer(cfg, sampler=None):
            # download="force",
            download=True,
            streaming=False,
-           root=str(DATA_PATH),
+           root=str(DATA_DIR),
            sampler=sampler,
            batch_size=batch_size,
            pin_memory=pin_memory,
@@ -74,7 +75,7 @@ def make_offline_buffer(cfg, sampler=None):
        offline_buffer = PushtExperienceReplay(
            "pusht",
            streaming=False,
-           root=DATA_PATH,
+           root=DATA_DIR,
            sampler=sampler,
            batch_size=batch_size,
            pin_memory=pin_memory,

View File

@@ -1,3 +1,4 @@
+import logging
 import os
 from pathlib import Path
 from typing import Callable
@@ -16,9 +17,10 @@ from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer

-from diffusion_policy.common.replay_buffer import ReplayBuffer
+from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
 from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
-from lerobot.common.datasets import utils
+from lerobot.common.datasets.utils import download_and_extract_zip
+from lerobot.common.envs.transforms import NormalizeTransform

 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
@@ -132,29 +134,16 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         else:
             storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))

-        # if num_slices is not None or slice_len is not None:
-        #     if sampler is not None:
-        #         raise ValueError(
-        #             "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
-        #         )
-        #     if replacement:
-        #         if not self.shuffle:
-        #             raise RuntimeError(
-        #                 "shuffle=False can only be used when replacement=False."
-        #             )
-        #         sampler = SliceSampler(
-        #             num_slices=num_slices,
-        #             slice_len=slice_len,
-        #             strict_length=strict_length,
-        #         )
-        #     else:
-        #         sampler = SliceSamplerWithoutReplacement(
-        #             num_slices=num_slices,
-        #             slice_len=slice_len,
-        #             strict_length=strict_length,
-        #             shuffle=self.shuffle,
-        #         )
+        mean_std = self._compute_or_load_mean_std(storage)
+        mean_std["next", "observation", "image"] = mean_std["observation", "image"]
+        mean_std["next", "observation", "state"] = mean_std["observation", "state"]
+        transform = NormalizeTransform(mean_std, in_keys=[
+            ("observation", "image"),
+            ("observation", "state"),
+            ("next", "observation", "image"),
+            ("next", "observation", "state"),
+            ("action"),
+        ])

         if writer is None:
             writer = ImmutableDatasetWriter()
@@ -193,10 +182,10 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         zarr_path = (raw_dir / PUSHT_ZARR).resolve()
         if not zarr_path.is_dir():
             raw_dir.mkdir(parents=True, exist_ok=True)
-            utils.download_and_extract_zip(PUSHT_URL, raw_dir)
+            download_and_extract_zip(PUSHT_URL, raw_dir)

         # load
-        dataset_dict = ReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
+        dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)  # , keys=['img', 'state', 'action'])
         episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
         num_episodes = dataset_dict.meta["episode_ends"].shape[0]
@@ -287,3 +276,62 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
             idxtd = idxtd + len(episode)

         return TensorStorage(td_data.lock_())
+
+    def _compute_mean_std(self, storage, num_batch=10, batch_size=32):
+        rb = TensorDictReplayBuffer(
+            storage=storage,
+            batch_size=batch_size,
+            prefetch=True,
+        )
+        batch = rb.sample()
+
+        image_mean = torch.zeros(batch["observation", "image"].shape[1])
+        image_std = torch.zeros(batch["observation", "image"].shape[1])
+        state_mean = torch.zeros(batch["observation", "state"].shape[1])
+        state_std = torch.zeros(batch["observation", "state"].shape[1])
+        action_mean = torch.zeros(batch["action"].shape[1])
+        action_std = torch.zeros(batch["action"].shape[1])
+
+        for i in tqdm.tqdm(range(num_batch)):
+            image_mean += einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+            state_mean += batch["observation", "state"].mean(dim=0)
+            action_mean += batch["action"].mean(dim=0)
+            batch = rb.sample()
+
+        image_mean /= num_batch
+        state_mean /= num_batch
+        action_mean /= num_batch
+
+        for i in tqdm.tqdm(range(num_batch)):
+            image_mean_batch = einops.reduce(batch["observation", "image"], 'b c h w -> c', reduction='mean')
+            image_std += (image_mean_batch - image_mean) ** 2
+            state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
+            action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
+            if i < num_batch - 1:
+                batch = rb.sample()
+
+        image_std = torch.sqrt(image_std / num_batch)
+        state_std = torch.sqrt(state_std / num_batch)
+        action_std = torch.sqrt(action_std / num_batch)
+
+        mean_std = TensorDict(
+            {
+                ("observation", "image", "mean"): image_mean[None, :, None, None],
+                ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "state", "mean"): state_mean[None, :],
+                ("observation", "state", "std"): state_std[None, :],
+                ("action", "mean"): action_mean[None, :],
+                ("action", "std"): action_std[None, :],
+            },
+            batch_size=[],
+        )
+        return mean_std
+
+    def _compute_or_load_mean_std(self, storage) -> TensorDict:
+        mean_std_path = self.root / self.dataset_id / "mean_std.pth"
+        if mean_std_path.exists():
+            mean_std = torch.load(mean_std_path)
+        else:
+            logging.info(f"compute_mean_std and save to {mean_std_path}")
+            mean_std = self._compute_mean_std(storage)
+            torch.save(mean_std, mean_std_path)
+        return mean_std
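
For reference, the new `_compute_or_load_mean_std` follows a compute-once-then-cache pattern: reuse the saved statistics if present, otherwise run the expensive scan and persist it. A minimal standalone sketch of that pattern, with a hypothetical cache path and a dummy compute function standing in for the two-pass `_compute_mean_std` scan (values are illustrative, not the real PushT statistics):

from pathlib import Path

import torch


def compute_or_load_stats(cache_path: Path, compute_fn) -> dict:
    # Load cached statistics if they exist; otherwise compute and persist them.
    if cache_path.exists():
        return torch.load(cache_path)
    stats = compute_fn()
    cache_path.parent.mkdir(parents=True, exist_ok=True)  # extra safety, not in the diff
    torch.save(stats, cache_path)
    return stats


# Dummy compute function in place of the real batch-sampling pass.
stats = compute_or_load_stats(
    Path("/tmp/pusht_mean_std.pth"),
    lambda: {"mean": torch.zeros(2), "std": torch.ones(2)},
)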

View File

@@ -3,9 +3,11 @@ import zipfile
 from pathlib import Path

 import requests
+from tensordict import TensorDictBase
 import tqdm


 def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     print(f"downloading from {url}")
     response = requests.get(url, stream=True)
@@ -28,3 +30,4 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return True
     else:
         return False

View File

@@ -3,7 +3,7 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv
 from lerobot.common.envs.transforms import Prod


-def make_env(cfg):
+def make_env(cfg, transform=None):
     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
@@ -32,6 +32,10 @@ def make_env(cfg):
         # to ensure pusht is in [0,255] like simxarm
         env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))

+    if transform is not None:
+        # useful to add mean and std normalization
+        env.append_transform(transform)
+
     return env

View File

@@ -100,7 +100,8 @@ class PushtEnv(EnvBase):
     def _step(self, tensordict: TensorDict):
         td = tensordict
-        action = td["action"].numpy()
+        # remove batch dim
+        action = td["action"].squeeze(0).numpy()
         # step expects shape=(4,) so we pad if necessary
         # TODO(rcadene): add info["is_success"] and info["success"] ?
         sum_reward = 0

View File

@@ -1,7 +1,10 @@
 from typing import Sequence

+from tensordict import TensorDictBase
 from tensordict.utils import NestedKey
 from torchrl.envs.transforms import ObservationTransform
+from torchrl.envs.transforms import Transform
+from tensordict.nn import dispatch


 class Prod(ObservationTransform):
@@ -19,3 +22,47 @@ class Prod(ObservationTransform):
         for key in self.in_keys:
             obs_spec[key].space.high *= self.prod
         return obs_spec
+
+
+class NormalizeTransform(Transform):
+    invertible = True
+
+    def __init__(self,
+                 mean_std: TensorDictBase,
+                 in_keys: Sequence[NestedKey] = None,
+                 out_keys: Sequence[NestedKey] | None = None,
+                 in_keys_inv: Sequence[NestedKey] | None = None,
+                 out_keys_inv: Sequence[NestedKey] | None = None,
+                 ):
+        if out_keys is None:
+            out_keys = in_keys
+        if in_keys_inv is None:
+            in_keys_inv = out_keys
+        if out_keys_inv is None:
+            out_keys_inv = in_keys
+        super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv)
+        self.mean_std = mean_std
+
+    @dispatch(source="in_keys", dest="out_keys")
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        return self._call(tensordict)
+
+    def _call(self, td: TensorDictBase) -> TensorDictBase:
+        for inkey, outkey in zip(self.in_keys, self.out_keys):
+            # TODO(rcadene): don't know how to do `inkey not in td`
+            if td.get(inkey, None) is None:
+                continue
+            mean = self.mean_std[inkey]["mean"]
+            std = self.mean_std[inkey]["std"]
+            td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+        return td
+
+    def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
+        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv):
+            # TODO(rcadene): don't know how to do `inkey not in td`
+            if td.get(inkey, None) is None:
+                continue
+            mean = self.mean_std[inkey]["mean"]
+            std = self.mean_std[inkey]["std"]
+            td[outkey] = td[inkey] * std + mean
+        return td
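
For reference, the forward/inverse math of `NormalizeTransform` boiled down to plain tensors, without the torchrl plumbing. The key name and statistics below are made up for illustration:

import torch

# Illustrative per-key statistics (not real dataset values).
stats = {"action": {"mean": torch.tensor([0.0, 0.0]), "std": torch.tensor([1.0, 2.0])}}


def normalize(td: dict) -> dict:
    # Same rule as NormalizeTransform._call: (x - mean) / (std + eps).
    for key, s in stats.items():
        if key in td:
            td[key] = (td[key] - s["mean"]) / (s["std"] + 1e-8)
    return td


def unnormalize(td: dict) -> dict:
    # Inverse rule, as in NormalizeTransform._inv_call: x * std + mean.
    for key, s in stats.items():
        if key in td:
            td[key] = td[key] * s["std"] + s["mean"]
    return td


sample = {"action": torch.tensor([[1.0, 4.0]])}
restored = unnormalize(normalize(sample))  # matches the input up to the eps term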

View File

@@ -5,6 +5,7 @@ from copy import deepcopy

 import einops
 import numpy as np
+from tensordict import TensorDict
 import torch
 import torch.nn as nn
@@ -126,19 +127,30 @@ class TDMPC(nn.Module):
     @torch.no_grad()
     def forward(self, observation, step_count):
         t0 = step_count.item() == 0
+
+        # TODO(rcadene): remove unsqueeze hack...
+        if observation["image"].ndim == 3:
+            observation["image"] = observation["image"].unsqueeze(0)
+            observation["state"] = observation["state"].unsqueeze(0)
+
         obs = {
-            "rgb": observation["image"],
-            "state": observation["state"],
+            # TODO(rcadene): remove contiguous hack...
+            "rgb": observation["image"].contiguous(),
+            "state": observation["state"].contiguous(),
         }
-        return self.act(obs, t0=t0, step=self.step.item())
+        action = self.act(obs, t0=t0, step=self.step.item())
+        # TODO(rcadene): hack to postprocess action (e.g. unnormalize)
+        # action = action * self.action_std + self.action_mean
+        return action

     @torch.no_grad()
     def act(self, obs, t0=False, step=None):
         """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
         if isinstance(obs, dict):
-            obs = {k: o.detach().unsqueeze(0) for k, o in obs.items()}
+            obs = {k: o.detach() for k, o in obs.items()}
         else:
-            obs = obs.detach().unsqueeze(0)
+            obs = obs.detach()
         z = self.model.encode(obs)
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
@@ -315,26 +327,20 @@
             # trajectory t = 256, horizon h = 5
             # (t h) ... -> h t ...
             batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-            batch = batch.to(self.device)

             obs = {
-                "rgb": batch["observation", "image"][FIRST_FRAME].float(),
-                "state": batch["observation", "state"][FIRST_FRAME],
+                "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
+                "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
             }
-            action = batch["action"]
+            action = batch["action"].to(self.device, non_blocking=True)
             next_obses = {
-                "rgb": batch["next", "observation", "image"].float(),
-                "state": batch["next", "observation", "state"],
+                "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
+                "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
             }
-            reward = batch["next", "reward"]
-
-            # TODO(rcadene): add non_blocking=True
-            # for key in obs:
-            #     obs[key] = obs[key].to(self.device, non_blocking=True)
-            #     next_obses[key] = next_obses[key].to(self.device, non_blocking=True)
-            # action = action.to(self.device, non_blocking=True)
-            # reward = reward.to(self.device, non_blocking=True)
+            reward = batch["next", "reward"].to(self.device, non_blocking=True)
+            idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
+            weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)

             # TODO(rcadene): rearrange directly in offline dataset
             if reward.ndim == 2:
@@ -347,9 +353,6 @@
             # Neither does `batch["next", "terminated"]`
             done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
             mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-            idxs = batch["index"][FIRST_FRAME]
-            weights = batch["_weight"][FIRST_FRAME, :, None]
             return obs, action, next_obses, reward, mask, done, idxs, weights

         batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
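
The per-tensor `.to(self.device, non_blocking=True)` calls above replace a single blocking `batch.to(self.device)`. A minimal sketch of the pattern, assuming the sampled tensors sit in pinned host memory (the offline buffer is built with `pin_memory=pin_memory` in the factory); without pinned memory the copy silently falls back to a synchronous transfer:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fake CPU batch standing in for a replay-buffer sample (shapes are arbitrary).
batch = {
    "rgb": torch.rand(256, 3, 84, 84),
    "state": torch.rand(256, 4),
    "action": torch.rand(256, 4),
    "reward": torch.rand(256, 1),
}

if device.type == "cuda":
    # pin_memory() gives page-locked host memory, which is what lets
    # non_blocking=True overlap the host-to-device copy with GPU compute.
    batch = {k: v.pin_memory() for k, v in batch.items()}

batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}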

View File

@@ -1,3 +1,4 @@
+import logging
 import threading
 import time
 from pathlib import Path
@@ -10,6 +11,7 @@ import tqdm
 from tensordict.nn import TensorDictModule
 from termcolor import colored
 from torchrl.envs import EnvBase

+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.policies.factory import make_policy
@@ -112,7 +114,11 @@ def eval(cfg: dict, out_dir=None):
     set_seed(cfg.seed)
     print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)

-    env = make_env(cfg)
+    logging.info("make_offline_buffer")
+    offline_buffer = make_offline_buffer(cfg)
+
+    logging.info("make_env")
+    env = make_env(cfg, transform=offline_buffer.transform)

     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)

View File

@@ -117,21 +117,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
     assert torch.cuda.is_available()
     torch.backends.cudnn.benchmark = True
+    torch.backends.cuda.matmul.allow_tf32 = True

     set_seed(cfg.seed)
     logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")

-    logging.info("make_env")
-    env = make_env(cfg)
-
-    logging.info("make_policy")
-    policy = make_policy(cfg)
-
-    td_policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
-
     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)
@@ -151,8 +140,22 @@
     online_buffer = TensorDictReplayBuffer(
         storage=LazyMemmapStorage(100_000),
         sampler=online_sampler,
+        transform=offline_buffer._transform,
     )

+    logging.info("make_env")
+    env = make_env(cfg, transform=offline_buffer._transform)
+
+    logging.info("make_policy")
+    policy = make_policy(cfg, transform=offline_buffer._transform)
+
+    td_policy = TensorDictModule(
+        policy,
+        in_keys=["observation", "step_count"],
+        out_keys=["action"],
+    )
+
+    # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)

     step = 0  # number of policy update
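
The reordering in train.py (and the matching change in eval.py) exists so that the env, policy, and online buffer can all reuse the normalization transform built by make_offline_buffer, which must therefore be constructed first. A toy sketch of that dependency with stub factories (these are not the real lerobot signatures):

from typing import Callable, Optional


def make_offline_buffer_stub() -> dict:
    # Pretend normalization transform; in the real code this is the
    # NormalizeTransform attached by PushtExperienceReplay.
    return {"transform": lambda x: (x - 0.5) / 0.25}


def make_env_stub(transform: Optional[Callable] = None) -> dict:
    return {"transform": transform}


def make_policy_stub(transform: Optional[Callable] = None) -> dict:
    return {"transform": transform}


offline_buffer = make_offline_buffer_stub()  # built first: it owns the dataset stats
env = make_env_stub(transform=offline_buffer["transform"])
policy = make_policy_stub(transform=offline_buffer["transform"])
assert env["transform"] is offline_buffer["transform"] is policy["transform"]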