Merge pull request #6 from Cadene/user/rcadene/2024_03_04_diffusion

Make diffusion work
Remi 2024-03-04 18:30:40 +01:00 committed by GitHub
commit e990f3e148
12 changed files with 276 additions and 142 deletions

View File

@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
-            prefetch=prefetch,
+            prefetch=prefetch if isinstance(prefetch, int) else None,
         )
     elif cfg.env.name == "pusht":
         offline_buffer = PushtExperienceReplay(
@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
             sampler=sampler,
             batch_size=batch_size,
             pin_memory=pin_memory,
-            prefetch=prefetch,
+            prefetch=prefetch if isinstance(prefetch, int) else None,
         )
     else:
         raise ValueError(cfg.env.name)

View File

@@ -1,4 +1,5 @@
 import logging
+import math
 import os
 from pathlib import Path
 from typing import Callable
@@ -134,20 +135,32 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         else:
             storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))

-        mean_std = self._compute_or_load_mean_std(storage)
-        mean_std["next", "observation", "image"] = mean_std["observation", "image"]
-        mean_std["next", "observation", "state"] = mean_std["observation", "state"]
+        stats = self._compute_or_load_stats(storage)
         transform = NormalizeTransform(
-            mean_std,
+            stats,
             in_keys=[
-                ("observation", "image"),
+                # TODO(rcadene): imagenet normalization is applied inside diffusion policy
+                # We need to automate this for tdmpc and others
+                # ("observation", "image"),
                 ("observation", "state"),
-                ("next", "observation", "image"),
-                ("next", "observation", "state"),
+                # TODO(rcadene): for tdmpc, we might want next image and state
+                # ("next", "observation", "image"),
+                # ("next", "observation", "state"),
                 ("action"),
             ],
+            mode="min_max",
         )
+        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
+        transform.stats["observation", "state", "min"] = torch.tensor(
+            [13.456424, 32.938293], dtype=torch.float32
+        )
+        transform.stats["observation", "state", "max"] = torch.tensor(
+            [496.14618, 510.9579], dtype=torch.float32
+        )
+        transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+        transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

         if writer is None:
             writer = ImmutableDatasetWriter()
         if collate_fn is None:
@@ -282,24 +295,50 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         return TensorStorage(td_data.lock_())

-    def _compute_mean_std(self, storage, num_batch=10, batch_size=32):
+    def _compute_stats(self, storage, num_batch=100, batch_size=32):
         rb = TensorDictReplayBuffer(
             storage=storage,
             batch_size=batch_size,
             prefetch=True,
         )
         batch = rb.sample()
-        image_mean = torch.zeros(batch["observation", "image"].shape[1])
-        image_std = torch.zeros(batch["observation", "image"].shape[1])
-        state_mean = torch.zeros(batch["observation", "state"].shape[1])
-        state_std = torch.zeros(batch["observation", "state"].shape[1])
-        action_mean = torch.zeros(batch["action"].shape[1])
-        action_std = torch.zeros(batch["action"].shape[1])
+        image_channels = batch["observation", "image"].shape[1]
+        image_mean = torch.zeros(image_channels)
+        image_std = torch.zeros(image_channels)
+        image_max = torch.tensor([-math.inf] * image_channels)
+        image_min = torch.tensor([math.inf] * image_channels)
+        state_channels = batch["observation", "state"].shape[1]
+        state_mean = torch.zeros(state_channels)
+        state_std = torch.zeros(state_channels)
+        state_max = torch.tensor([-math.inf] * state_channels)
+        state_min = torch.tensor([math.inf] * state_channels)
+        action_channels = batch["action"].shape[1]
+        action_mean = torch.zeros(action_channels)
+        action_std = torch.zeros(action_channels)
+        action_max = torch.tensor([-math.inf] * action_channels)
+        action_min = torch.tensor([math.inf] * action_channels)

         for _ in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
-            state_mean += batch["observation", "state"].mean(dim=0)
-            action_mean += batch["action"].mean(dim=0)
+            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
+            state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
+            action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
+            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
+            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
+            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
+            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
+            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
+            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
+            image_max = torch.maximum(image_max, b_image_max)
+            image_min = torch.maximum(image_min, b_image_min)
+            state_max = torch.maximum(state_max, b_state_max)
+            state_min = torch.maximum(state_min, b_state_min)
+            action_max = torch.maximum(action_max, b_action_max)
+            action_min = torch.maximum(action_min, b_action_min)
             batch = rb.sample()

         image_mean /= num_batch
@@ -307,10 +346,26 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         action_mean /= num_batch

         for i in tqdm.tqdm(range(num_batch)):
-            image_mean_batch = einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
-            image_std += (image_mean_batch - image_mean) ** 2
-            state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
-            action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
+            b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
+            b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
+            b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
+            image_std += (b_image_mean - image_mean) ** 2
+            state_std += (b_state_mean - state_mean) ** 2
+            action_std += (b_action_mean - action_mean) ** 2
+            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
+            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
+            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
+            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
+            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
+            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
+            image_max = torch.maximum(image_max, b_image_max)
+            image_min = torch.maximum(image_min, b_image_min)
+            state_max = torch.maximum(state_max, b_state_max)
+            state_min = torch.maximum(state_min, b_state_min)
+            action_max = torch.maximum(action_max, b_action_max)
+            action_min = torch.maximum(action_min, b_action_min)
             if i < num_batch - 1:
                 batch = rb.sample()
@@ -318,25 +373,33 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
         state_std = torch.sqrt(state_std / num_batch)
         action_std = torch.sqrt(action_std / num_batch)

-        mean_std = TensorDict(
+        stats = TensorDict(
             {
                 ("observation", "image", "mean"): image_mean[None, :, None, None],
                 ("observation", "image", "std"): image_std[None, :, None, None],
+                ("observation", "image", "max"): image_max[None, :, None, None],
+                ("observation", "image", "min"): image_min[None, :, None, None],
                 ("observation", "state", "mean"): state_mean[None, :],
                 ("observation", "state", "std"): state_std[None, :],
+                ("observation", "state", "max"): state_max[None, :],
+                ("observation", "state", "min"): state_min[None, :],
                 ("action", "mean"): action_mean[None, :],
                 ("action", "std"): action_std[None, :],
+                ("action", "max"): action_max[None, :],
+                ("action", "min"): action_min[None, :],
             },
             batch_size=[],
         )
-        return mean_std
+        stats["next", "observation", "image"] = stats["observation", "image"]
+        stats["next", "observation", "state"] = stats["observation", "state"]
+        return stats

-    def _compute_or_load_mean_std(self, storage) -> TensorDict:
-        mean_std_path = self.root / self.dataset_id / "mean_std.pth"
-        if mean_std_path.exists():
-            mean_std = torch.load(mean_std_path)
+    def _compute_or_load_stats(self, storage) -> TensorDict:
+        stats_path = self.root / self.dataset_id / "stats.pth"
+        if stats_path.exists():
+            stats = torch.load(stats_path)
         else:
-            logging.info(f"compute_mean_std and save to {mean_std_path}")
-            mean_std = self._compute_mean_std(storage)
-            torch.save(mean_std, mean_std_path)
-        return mean_std
+            logging.info(f"compute_stats and save to {stats_path}")
+            stats = self._compute_stats(storage)
+            torch.save(stats, stats_path)
+        return stats
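
Aside from the diff itself, here is a minimal standalone sketch of the statistics strategy used by _compute_stats above: per-channel means are accumulated over sampled batches, the std is then estimated from the spread of per-batch means around the global mean, and min/max are tracked with running element-wise comparisons (the sketch uses torch.minimum for the running minimum). The helper name and the sample() callable are illustrative, not the repository's API.

import math
import torch
import einops

def compute_image_stats(sample, num_batch=100):
    # `sample()` is assumed to return a dict with an "image" tensor of shape (b, c, h, w)
    batches = [sample() for _ in range(num_batch)]
    c = batches[0]["image"].shape[1]

    mean = torch.zeros(c)
    vmax = torch.full((c,), -math.inf)
    vmin = torch.full((c,), math.inf)
    for batch in batches:
        mean += einops.reduce(batch["image"], "b c h w -> c", "mean")
        vmax = torch.maximum(vmax, einops.reduce(batch["image"], "b c h w -> c", "max"))
        vmin = torch.minimum(vmin, einops.reduce(batch["image"], "b c h w -> c", "min"))
    mean /= num_batch

    # second pass: std of the per-batch means around the global mean (the diff's estimator)
    std = torch.zeros(c)
    for batch in batches:
        std += (einops.reduce(batch["image"], "b c h w -> c", "mean") - mean) ** 2
    std = torch.sqrt(std / num_batch)

    return {"mean": mean, "std": std, "max": vmax, "min": vmin}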

View File

@@ -1,7 +1,5 @@
 from torchrl.envs.transforms import StepCounter, TransformedEnv

-from lerobot.common.envs.transforms import Prod

 def make_env(cfg, transform=None):
     kwargs = {
@@ -9,6 +7,8 @@ def make_env(cfg, transform=None):
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
+        # TODO(rcadene): do we want a specific eval_env_seed?
+        "seed": cfg.seed,
     }

     if cfg.env.name == "simxarm":
@@ -19,6 +19,8 @@ def make_env(cfg, transform=None):
     elif cfg.env.name == "pusht":
         from lerobot.common.envs.pusht import PushtEnv

+        # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
+
         clsfunc = PushtEnv
     else:
         raise ValueError(cfg.env.name)
@@ -28,12 +30,8 @@ def make_env(cfg, transform=None):
     # limit rollout to max_steps
     env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))

-    if cfg.env.name == "pusht":
-        # to ensure pusht is in [0,255] like simxarm
-        env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
-
     if transform is not None:
-        # useful to add mean and std normalization
+        # useful to add normalization
        env.append_transform(transform)

     return env

View File

@@ -1,4 +1,5 @@
 import importlib
+from collections import deque
 from typing import Optional

 import torch
@@ -27,12 +28,16 @@ class PushtEnv(EnvBase):
         image_size=None,
         seed=1337,
         device="cpu",
+        num_prev_obs=1,
+        num_prev_action=0,
     ):
         super().__init__(device=device, batch_size=[])
         self.frame_skip = frame_skip
         self.from_pixels = from_pixels
         self.pixels_only = pixels_only
         self.image_size = image_size
+        self.num_prev_obs = num_prev_obs
+        self.num_prev_action = num_prev_action

         if pixels_only:
             assert from_pixels
@@ -56,6 +61,12 @@ class PushtEnv(EnvBase):
         self._make_spec()
         self._current_seed = self.set_seed(seed)

+        if self.num_prev_obs > 0:
+            self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
+            self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
+        if self.num_prev_action > 0:
+            self._prev_action_queue = deque(maxlen=self.num_prev_action)
+
     def render(self, mode="rgb_array", width=384, height=384):
         if width != height:
             raise NotImplementedError()
@@ -67,7 +78,8 @@ class PushtEnv(EnvBase):
     def _format_raw_obs(self, raw_obs):
         if self.from_pixels:
-            obs = {"image": torch.from_numpy(raw_obs["image"])}
+            image = torch.from_numpy(raw_obs["image"])
+            obs = {"image": image}

             if not self.pixels_only:
                 obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
@@ -75,7 +87,6 @@ class PushtEnv(EnvBase):
             # TODO:
             obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}

-        obs = TensorDict(obs, batch_size=[])
         return obs

     def _reset(self, tensordict: Optional[TensorDict] = None):
@@ -87,9 +98,25 @@ class PushtEnv(EnvBase):
             raw_obs = self._env.reset()
             assert self._current_seed == self._env._seed

+        obs = self._format_raw_obs(raw_obs)
+
+        if self.num_prev_obs > 0:
+            stacked_obs = {}
+            if "image" in obs:
+                self._prev_obs_image_queue = deque(
+                    [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                )
+                stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+            if "state" in obs:
+                self._prev_obs_state_queue = deque(
+                    [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
+                )
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs
+
         td = TensorDict(
             {
-                "observation": self._format_raw_obs(raw_obs),
+                "observation": TensorDict(obs, batch_size=[]),
                 "done": torch.tensor([False], dtype=torch.bool),
             },
             batch_size=[],
@@ -100,18 +127,37 @@ class PushtEnv(EnvBase):
     def _step(self, tensordict: TensorDict):
         td = tensordict
-        # remove batch dim
-        action = td["action"].squeeze(0).numpy()
+        action = td["action"].numpy()
         # step expects shape=(4,) so we pad if necessary
         # TODO(rcadene): add info["is_success"] and info["success"] ?

         sum_reward = 0
-        for _ in range(self.frame_skip):
-            raw_obs, reward, done, info = self._env.step(action)
+
+        if action.ndim == 1:
+            action = action.repeat(self.frame_skip, 1)
+        else:
+            if self.frame_skip > 1:
+                raise NotImplementedError()
+
+        num_action_steps = action.shape[0]
+        for i in range(num_action_steps):
+            raw_obs, reward, done, info = self._env.step(action[i])
             sum_reward += reward
+
+        obs = self._format_raw_obs(raw_obs)
+
+        if self.num_prev_obs > 0:
+            stacked_obs = {}
+            if "image" in obs:
+                self._prev_obs_image_queue.append(obs["image"])
+                stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
+            if "state" in obs:
+                self._prev_obs_state_queue.append(obs["state"])
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs
+
         td = TensorDict(
             {
-                "observation": self._format_raw_obs(raw_obs),
+                "observation": TensorDict(obs, batch_size=[]),
                 "reward": torch.tensor([sum_reward], dtype=torch.float32),
                 # succes and done are true when coverage > self.success_threshold in env
                 "done": torch.tensor([done], dtype=torch.bool),
@@ -124,14 +170,22 @@ class PushtEnv(EnvBase):
     def _make_spec(self):
         obs = {}
         if self.from_pixels:
+            image_shape = (3, self.image_size, self.image_size)
+            if self.num_prev_obs > 0:
+                image_shape = (self.num_prev_obs, *image_shape)
+
             obs["image"] = BoundedTensorSpec(
                 low=0,
                 high=1,
-                shape=(3, self.image_size, self.image_size),
+                shape=image_shape,
                 dtype=torch.float32,
                 device=self.device,
             )
             if not self.pixels_only:
+                state_shape = self._env.observation_space["agent_pos"].shape
+                if self.num_prev_obs > 0:
+                    state_shape = (self.num_prev_obs, *state_shape)
+
                 obs["state"] = BoundedTensorSpec(
                     low=0,
                     high=512,
@@ -141,6 +195,10 @@ class PushtEnv(EnvBase):
                 )
         else:
             # TODO(rcadene): add observation_space achieved_goal and desired_goal?
+            state_shape = self._env.observation_space["observation"].shape
+            if self.num_prev_obs > 0:
+                state_shape = (self.num_prev_obs, *state_shape)
+
             obs["state"] = UnboundedContinuousTensorSpec(
                 # TODO:
                 shape=self._env.observation_space["observation"].shape,
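
Outside the diff, here is a small sketch of the deque-based frame stacking PushtEnv now performs (the helper class and tensor sizes are illustrative): on reset the queue is filled with num_prev_obs + 1 copies of the first observation, each step appends the newest observation and drops the oldest, and torch.stack yields a tensor with a leading temporal dimension.

from collections import deque
import torch

class ObsStacker:
    """Keep the current observation plus `num_prev_obs` previous ones (illustrative helper)."""

    def __init__(self, num_prev_obs: int):
        self.queue = deque(maxlen=num_prev_obs + 1)

    def reset(self, obs: torch.Tensor) -> torch.Tensor:
        # fill the queue with copies of the first observation, as _reset does above
        self.queue = deque([obs] * self.queue.maxlen, maxlen=self.queue.maxlen)
        return torch.stack(list(self.queue))

    def step(self, obs: torch.Tensor) -> torch.Tensor:
        # appending to a full deque discards the oldest entry
        self.queue.append(obs)
        return torch.stack(list(self.queue))

stacker = ObsStacker(num_prev_obs=1)
print(stacker.reset(torch.zeros(3, 96, 96)).shape)  # torch.Size([2, 3, 96, 96])
print(stacker.step(torch.ones(3, 96, 96)).shape)    # torch.Size([2, 3, 96, 96])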

View File

@@ -28,11 +28,12 @@ class NormalizeTransform(Transform):
     def __init__(
         self,
-        mean_std: TensorDictBase,
+        stats: TensorDictBase,
         in_keys: Sequence[NestedKey] = None,
         out_keys: Sequence[NestedKey] | None = None,
         in_keys_inv: Sequence[NestedKey] | None = None,
         out_keys_inv: Sequence[NestedKey] | None = None,
+        mode="mean_std",
     ):
         if out_keys is None:
             out_keys = in_keys
@@ -43,7 +44,14 @@ class NormalizeTransform(Transform):
         super().__init__(
             in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
         )
-        self.mean_std = mean_std
+        self.stats = stats
+        assert mode in ["mean_std", "min_max"]
+        self.mode = mode
+
+    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
+        # _reset is called once when the environment reset to normalize the first observation
+        tensordict_reset = self._call(tensordict_reset)
+        return tensordict_reset

     @dispatch(source="in_keys", dest="out_keys")
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -54,9 +62,17 @@ class NormalizeTransform(Transform):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
-            mean = self.mean_std[inkey]["mean"]
-            std = self.mean_std[inkey]["std"]
-            td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+            if self.mode == "mean_std":
+                mean = self.stats[inkey]["mean"]
+                std = self.stats[inkey]["std"]
+                td[outkey] = (td[inkey] - mean) / (std + 1e-8)
+            else:
+                min = self.stats[inkey]["min"]
+                max = self.stats[inkey]["max"]
+                # normalize to [0,1]
+                td[outkey] = (td[inkey] - min) / (max - min)
+                # normalize to [-1, 1]
+                td[outkey] = td[outkey] * 2 - 1
         return td

     def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
@@ -64,7 +80,13 @@ class NormalizeTransform(Transform):
             # TODO(rcadene): don't know how to do `inkey not in td`
             if td.get(inkey, None) is None:
                 continue
-            mean = self.mean_std[inkey]["mean"]
-            std = self.mean_std[inkey]["std"]
-            td[outkey] = td[inkey] * std + mean
+            if self.mode == "mean_std":
+                mean = self.stats[inkey]["mean"]
+                std = self.stats[inkey]["std"]
+                td[outkey] = td[inkey] * std + mean
+            else:
+                min = self.stats[inkey]["min"]
+                max = self.stats[inkey]["max"]
+                td[outkey] = (td[inkey] + 1) / 2
+                td[outkey] = td[outkey] * (max - min) + min
         return td
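
A quick numeric check (not part of the diff) of the new min_max mode and its inverse, using the manual action bounds set in the dataset patch above: values are mapped to [0, 1] with the stored min/max, then to [-1, 1], and _inv_call undoes both steps.

import torch

stat_min = torch.tensor([12.0, 25.0])    # action min from the dataset patch above
stat_max = torch.tensor([511.0, 511.0])  # action max from the dataset patch above

def normalize(x):
    x = (x - stat_min) / (stat_max - stat_min)  # -> [0, 1]
    return x * 2 - 1                            # -> [-1, 1]

def unnormalize(y):
    y = (y + 1) / 2
    return y * (stat_max - stat_min) + stat_min

action = torch.tensor([12.0, 511.0])
print(normalize(action))                                        # tensor([-1., 1.])
print(torch.allclose(unnormalize(normalize(action)), action))   # True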

View File

@@ -1,51 +1,11 @@
-import contextlib
+import logging
 import os
 from pathlib import Path

-import numpy as np
 from omegaconf import OmegaConf
 from termcolor import colored


-def make_dir(dir_path):
-    """Create directory if it does not already exist."""
-    with contextlib.suppress(OSError):
-        dir_path.mkdir(parents=True, exist_ok=True)
-    return dir_path
-
-
-def print_run(cfg, reward=None):
-    """Pretty-printing of run information. Call at start of training."""
-    prefix, color, attrs = " ", "green", ["bold"]
-
-    def limstr(s, maxlen=32):
-        return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
-
-    def pprint(k, v):
-        print(
-            prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
-            limstr(v),
-        )
-
-    kvs = [
-        ("task", cfg.env.task),
-        ("offline_steps", f"{cfg.offline_steps}"),
-        ("online_steps", f"{cfg.online_steps}"),
-        ("action_repeat", f"{cfg.env.action_repeat}"),
-        # ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
-        # ('actions', cfg.action_dim),
-        # ('experiment', cfg.exp_name),
-    ]
-
-    if reward is not None:
-        kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
-
-    w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
-    div = "-" * w
-    print(div)
-    for k, v in kvs:
-        pprint(k, v)
-    print(div)
-
-
 def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
     # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
@@ -71,13 +31,12 @@ class Logger:
         self._seed = cfg.seed
         self._cfg = cfg
         self._eval = []
-        print_run(cfg)
         project = cfg.get("wandb", {}).get("project")
         entity = cfg.get("wandb", {}).get("entity")
         enable_wandb = cfg.get("wandb", {}).get("enable", False)
         run_offline = not enable_wandb or not project or not entity
         if run_offline:
-            print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
             self._wandb = None
         else:
             os.environ["WANDB_SILENT"] = "true"
@@ -134,7 +93,6 @@ class Logger:
             self.save_buffer(buffer, identifier="buffer")
         if self._wandb:
             self._wandb.finish()
-        print_run(self._cfg, self._eval[-1][-1])

     def log_dict(self, d, step, mode="train"):
         assert mode in {"train", "eval"}
@@ -144,5 +102,5 @@ class Logger:
     def log_video(self, video, step, mode="train"):
         assert mode in {"train", "eval"}
-        wandb_video = self._wandb.Video(video, fps=self.cfg.fps, format="mp4")
+        wandb_video = self._wandb.Video(video, fps=self._cfg.fps, format="mp4")
         self._wandb.log({f"{mode}/video": wandb_video}, step=step)

View File

@@ -1,20 +1,15 @@
 import copy
 import time

-import einops
 import hydra
 import torch
 import torch.nn as nn
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
-from diffusion_policy.model.vision.model_getter import get_resnet

 from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from .multi_image_obs_encoder import MultiImageObsEncoder

-FIRST_ACTION = 0
-

 class DiffusionPolicy(nn.Module):
     def __init__(
@@ -42,8 +37,8 @@ class DiffusionPolicy(nn.Module):
         super().__init__()
         self.cfg = cfg

-        noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
-        rgb_model = get_resnet(**cfg_rgb_model)
+        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
+        rgb_model = hydra.utils.instantiate(cfg_rgb_model)
         obs_encoder = MultiImageObsEncoder(
             rgb_model=rgb_model,
             **cfg_obs_encoder,
@@ -101,20 +96,17 @@ class DiffusionPolicy(nn.Module):
         # TODO(rcadene): remove unused step_count
         del step_count

-        # TODO(rcadene): remove unsqueeze hack...
-        if observation["image"].ndim == 3:
-            observation["image"] = observation["image"].unsqueeze(0)
-            observation["state"] = observation["state"].unsqueeze(0)
+        # TODO(rcadene): remove unsqueeze hack to add bsize=1
+        observation["image"] = observation["image"].unsqueeze(0)
+        observation["state"] = observation["state"].unsqueeze(0)

         obs_dict = {
-            # TODO(rcadene): hack to add temporal dim
-            "image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
-            "agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
+            "image": observation["image"],
+            "agent_pos": observation["state"],
         }
         out = self.diffusion.predict_action(obs_dict)

-        # TODO(rcadene): add possibility to return >1 timestemps
-        action = out["action"].squeeze(0)[FIRST_ACTION]
+        action = out["action"].squeeze(0)
         return action

     def update(self, replay_buffer, step):
@@ -133,16 +125,36 @@ class DiffusionPolicy(nn.Module):
             # (t h) ... -> t h ...
             batch = batch.reshape(num_slices, horizon)  # .transpose(1, 0).contiguous()

+            # |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
+            # |o|o|                                   observations: 2
+            # | |a|a|a|a|a|a|a|a|                     actions executed: 8
+            # |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p|  actions predicted: 16
+            # note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
+
+            image = batch["observation", "image"]
+            state = batch["observation", "state"]
+            action = batch["action"]
+            assert image.shape[1] == horizon
+            assert state.shape[1] == horizon
+            assert action.shape[1] == horizon
+            if not (horizon == 16 and self.cfg.n_obs_steps == 2):
+                raise NotImplementedError()
+
+            # keep first 2 observations of the slice corresponding to t=[-1,0]
+            image = image[:, : self.cfg.n_obs_steps]
+            state = state[:, : self.cfg.n_obs_steps]
+
             out = {
                 "obs": {
-                    "image": batch["observation", "image"].to(self.device, non_blocking=True),
-                    "agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
+                    "image": image.to(self.device, non_blocking=True),
+                    "agent_pos": state.to(self.device, non_blocking=True),
                 },
-                "action": batch["action"].to(self.device, non_blocking=True),
+                "action": action.to(self.device, non_blocking=True),
             }
             return out

-        batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
+        batch = replay_buffer.sample(batch_size)
         batch = process_batch(batch, self.cfg.horizon, num_slices)

         data_s = time.time() - start_time
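
To make the slicing above concrete, a toy example (outside the diff) of what process_batch produces with horizon=16 and n_obs_steps=2: the policy is conditioned on the two observations at t=-1 and t=0, while the full 16-step action sequence stays the prediction target. Shapes are illustrative.

import torch

batch_size, horizon, n_obs_steps = 4, 16, 2
image = torch.randn(batch_size, horizon, 3, 96, 96)
state = torch.randn(batch_size, horizon, 2)
action = torch.randn(batch_size, horizon, 2)

obs = {
    "image": image[:, :n_obs_steps],      # (4, 2, 3, 96, 96): frames for t=[-1, 0]
    "agent_pos": state[:, :n_obs_steps],  # (4, 2, 2)
}
target_action = action                    # (4, 16, 2): full prediction horizon
print(obs["image"].shape, obs["agent_pos"].shape, target_action.shape)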

View File

@ -1,7 +1,7 @@
defaults: defaults:
- _self_ - _self_
- env: simxarm - env: pusht
- policy: tdmpc - policy: diffusion
hydra: hydra:
run: run:
@ -21,6 +21,7 @@ save_buffer: false
train_steps: ??? train_steps: ???
fps: ??? fps: ???
n_action_steps: ???
env: ??? env: ???
policy: ??? policy: ???

View File

@@ -13,7 +13,7 @@ shape_meta:
     shape: [2]

 horizon: 16
-n_obs_steps: 1 # TODO(rcadene): before 2
+n_obs_steps: 2
 n_action_steps: 8
 n_latency_steps: 0
 dataset_obs_steps: ${n_obs_steps}
@@ -21,7 +21,7 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True

-eval_episodes: 50
+eval_episodes: 1
 eval_freq: 10000
 save_freq: 100000
 log_freq: 250
@@ -40,8 +40,8 @@ policy:
   num_inference_steps: 100
   obs_as_global_cond: ${obs_as_global_cond}
   # crop_shape: null
-  diffusion_step_embed_dim: 128
-  down_dims: [512, 1024, 2048]
+  diffusion_step_embed_dim: 256 # before 128
+  down_dims: [256, 512, 1024] # before [512, 1024, 2048]
   kernel_size: 5
   n_groups: 8
   cond_predict_scale: True
@@ -59,10 +59,10 @@ policy:
   use_ema: true
   lr_scheduler: cosine
   lr_warmup_steps: 500
-  grad_clip_norm: 0
+  grad_clip_norm: 10

 noise_scheduler:
-  # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+  _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
   num_train_timesteps: 100
   beta_start: 0.0001
   beta_end: 0.02
@@ -74,16 +74,16 @@ noise_scheduler:

 obs_encoder:
   # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
   shape_meta: ${shape_meta}
-  resize_shape: null
-  crop_shape: [76, 76]
+  # resize_shape: null
+  # crop_shape: [76, 76]
   # constant center crop
-  random_crop: True
+  # random_crop: True
   use_group_norm: True
   share_rgb_model: False
-  imagenet_norm: False # TODO(rcadene): was set to True
+  imagenet_norm: True

 rgb_model:
-  #_target_: diffusion_policy.model.vision.model_getter.get_resnet
+  _target_: diffusion_policy.model.vision.model_getter.get_resnet
   name: resnet18
   weights: null
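
With _target_ now uncommented, the policy no longer imports DDPMScheduler and get_resnet directly; it builds both objects from the config via hydra.utils.instantiate (see the policy diff above). A minimal standalone illustration of how such a node is resolved (assumes diffusers is installed; not the repository's code):

import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    noise_scheduler:
      _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
      num_train_timesteps: 100
      beta_start: 0.0001
      beta_end: 0.02
    """
)
# instantiate imports the class named by _target_ and calls it with the remaining keys
noise_scheduler = hydra.utils.instantiate(cfg.noise_scheduler)
print(type(noise_scheduler).__name__)  # DDPMScheduler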

View File

@@ -1,5 +1,7 @@
 # @package _global_

+n_action_steps: 1
+
 policy:
   name: tdmpc

View File

@@ -118,7 +118,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)

     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = make_env(cfg, transform=offline_buffer._transform)

     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
@@ -137,7 +137,7 @@
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length,
+        max_steps=cfg.env.episode_length // cfg.n_action_steps,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)
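
The // cfg.n_action_steps factor reflects that each policy call now returns a chunk of actions rather than a single step, so the rollout needs proportionally fewer policy invocations. With illustrative numbers (episode_length is an assumed value; n_action_steps: 8 comes from the diffusion config above):

episode_length = 300   # assumed value for illustration
n_action_steps = 8     # from the diffusion config
max_steps = episode_length // n_action_steps
print(max_steps)                   # 37 policy calls per rollout
print(max_steps * n_action_steps)  # 296 environment steps actually executed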

View File

@@ -123,7 +123,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
-    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")

     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)
@@ -153,6 +152,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info("make_policy")
     policy = make_policy(cfg)

+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
     td_policy = TensorDictModule(
         policy,
         in_keys=["observation", "step_count"],
@@ -162,6 +164,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)

+    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
+    logging.info(f"{cfg.online_steps=}")
+    logging.info(f"{cfg.env.action_repeat=}")
+    logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
+    logging.info(f"{offline_buffer.num_episodes=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
     step = 0  # number of policy update
     is_offline = True
@@ -174,19 +186,23 @@ def train(cfg: dict, out_dir=None, job_name=None):
             log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)

         if step > 0 and step % cfg.eval_freq == 0:
+            logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
                 td_policy,
                 num_episodes=cfg.eval_episodes,
+                max_steps=cfg.env.episode_length // cfg.n_action_steps,
                 return_first_video=True,
             )
             log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(first_video, step, mode="eval")
+            logging.info("Resume training")

         if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint model at step {step}")
+            logging.info(f"Checkpoint policy at step {step}")
             logger.save_model(policy, identifier=step)
+            logging.info("Resume training")

         step += 1
@@ -200,11 +216,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
             rollout = env.rollout(
-                max_steps=cfg.env.episode_length,
+                max_steps=cfg.env.episode_length // cfg.n_action_steps,
                 policy=td_policy,
                 auto_cast_to_device=True,
             )
-        assert len(rollout) <= cfg.env.episode_length
+        assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
         # set same episode index for all time steps contained in this rollout
         rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
         online_buffer.extend(rollout)
@@ -231,19 +247,23 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)

             if step > 0 and step % cfg.eval_freq == 0:
+                logging.info(f"Eval policy at step {step}")
                 eval_info, first_video = eval_policy(
                     env,
                     td_policy,
                     num_episodes=cfg.eval_episodes,
+                    max_steps=cfg.env.episode_length // cfg.n_action_steps,
                     return_first_video=True,
                 )
                 log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
                 if cfg.wandb.enable:
                     logger.log_video(first_video, step, mode="eval")
+                logging.info("Resume training")

             if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-                logging.info(f"Checkpoint model at step {step}")
+                logging.info(f"Checkpoint policy at step {step}")
                 logger.save_model(policy, identifier=step)
+                logging.info("Resume training")

             step += 1
             online_step += 1