Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements

Author: Remi Cadene
Date: 2024-03-06 10:14:03 +00:00
Commit: f95ecd66fc (parent: 2f80d71c3e)

7 changed files with 195 additions and 150 deletions

File 1 of 7: AbstractExperienceReplay (datasets)

@@ -1,6 +1,5 @@
 import abc
 import logging
-import math
 from pathlib import Path
 from typing import Callable
@@ -50,6 +49,22 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
             transform=transform,
         )
 
+    @property
+    def stats_patterns(self) -> dict:
+        return {
+            ("observation", "state"): "b c -> 1 c",
+            ("observation", "image"): "b c h w -> 1 c 1 1",
+            ("action"): "b c -> 1 c",
+        }
+
+    @property
+    def image_keys(self) -> list:
+        return [("observation", "image")]
+
+    @property
+    def num_cameras(self) -> int:
+        return len(self.image_keys)
+
     @property
     def num_samples(self) -> int:
         return len(self)
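For context: stats_patterns is what lets the reworked _compute_stats (next hunks) treat state, image, and action uniformly. Each einops pattern collapses the batch (and, for images, the spatial) dimensions while keeping a per-channel shape that broadcasts back onto the data. A minimal standalone sketch of what one pattern computes, illustrative only, not part of the commit:

    import einops
    import torch

    # a fake batch of 8 RGB images (b=8, c=3, h=w=16)
    images = torch.rand(8, 3, 16, 16)

    # "b c h w -> 1 c 1 1" averages over batch and spatial dims,
    # keeping one value per channel, broadcastable onto (b, c, h, w)
    mean = einops.reduce(images, "b c h w -> 1 c 1 1", "mean")
    print(mean.shape)  # torch.Size([1, 3, 1, 1])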
@@ -67,7 +82,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
             stats = torch.load(stats_path)
         else:
             logging.info(f"compute_stats and save to {stats_path}")
-            stats = self._compute_stats(self._storage._storage, num_batch, batch_size)
+            stats = self._compute_stats(num_batch, batch_size)
             torch.save(stats, stats_path)
         return stats
@@ -85,101 +100,59 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     def _is_downloaded(self) -> bool:
         return self.data_dir.is_dir()
 
-    def _compute_stats(self, storage, num_batch=100, batch_size=32):
+    def _compute_stats(self, num_batch=100, batch_size=32):
         rb = TensorDictReplayBuffer(
-            storage=storage,
+            storage=self._storage,
             batch_size=batch_size,
             prefetch=True,
         )
-        batch = rb.sample()
-
-        image_channels = batch["observation", "image"].shape[1]
-        image_mean = torch.zeros(image_channels)
-        image_std = torch.zeros(image_channels)
-        image_max = torch.tensor([-math.inf] * image_channels)
-        image_min = torch.tensor([math.inf] * image_channels)
-
-        state_channels = batch["observation", "state"].shape[1]
-        state_mean = torch.zeros(state_channels)
-        state_std = torch.zeros(state_channels)
-        state_max = torch.tensor([-math.inf] * state_channels)
-        state_min = torch.tensor([math.inf] * state_channels)
-
-        action_channels = batch["action"].shape[1]
-        action_mean = torch.zeros(action_channels)
-        action_std = torch.zeros(action_channels)
-        action_max = torch.tensor([-math.inf] * action_channels)
-        action_min = torch.tensor([math.inf] * action_channels)
-
-        for _ in tqdm.tqdm(range(num_batch)):
-            image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
-            state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
-            action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
-            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
-            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
-            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
-            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
-            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
-            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
-            image_max = torch.maximum(image_max, b_image_max)
-            image_min = torch.maximum(image_min, b_image_min)
-            state_max = torch.maximum(state_max, b_state_max)
-            state_min = torch.maximum(state_min, b_state_min)
-            action_max = torch.maximum(action_max, b_action_max)
-            action_min = torch.maximum(action_min, b_action_min)
-            batch = rb.sample()
-
-        image_mean /= num_batch
-        state_mean /= num_batch
-        action_mean /= num_batch
-
-        for i in tqdm.tqdm(range(num_batch)):
-            b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
-            b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
-            b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
-            image_std += (b_image_mean - image_mean) ** 2
-            state_std += (b_state_mean - state_mean) ** 2
-            action_std += (b_action_mean - action_mean) ** 2
-            b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
-            b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
-            b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
-            b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
-            b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
-            b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
-            image_max = torch.maximum(image_max, b_image_max)
-            image_min = torch.maximum(image_min, b_image_min)
-            state_max = torch.maximum(state_max, b_state_max)
-            state_min = torch.maximum(state_min, b_state_min)
-            action_max = torch.maximum(action_max, b_action_max)
-            action_min = torch.maximum(action_min, b_action_min)
-            if i < num_batch - 1:
-                batch = rb.sample()
-
-        image_std = torch.sqrt(image_std / num_batch)
-        state_std = torch.sqrt(state_std / num_batch)
-        action_std = torch.sqrt(action_std / num_batch)
-
-        stats = TensorDict(
-            {
-                ("observation", "image", "mean"): image_mean[None, :, None, None],
-                ("observation", "image", "std"): image_std[None, :, None, None],
-                ("observation", "image", "max"): image_max[None, :, None, None],
-                ("observation", "image", "min"): image_min[None, :, None, None],
-                ("observation", "state", "mean"): state_mean[None, :],
-                ("observation", "state", "std"): state_std[None, :],
-                ("observation", "state", "max"): state_max[None, :],
-                ("observation", "state", "min"): state_min[None, :],
-                ("action", "mean"): action_mean[None, :],
-                ("action", "std"): action_std[None, :],
-                ("action", "max"): action_max[None, :],
-                ("action", "min"): action_min[None, :],
-            },
-            batch_size=[],
-        )
-        stats["next", "observation", "image"] = stats["observation", "image"]
-        stats["next", "observation", "state"] = stats["observation", "state"]
+
+        mean, std, max, min = {}, {}, {}, {}
+
+        # compute mean, min, max
+        for _ in tqdm.tqdm(range(num_batch)):
+            batch = rb.sample()
+            for key, pattern in self.stats_patterns.items():
+                batch[key] = batch[key].float()
+                if key not in mean:
+                    # first batch initialize mean, min, max
+                    mean[key] = einops.reduce(batch[key], pattern, "mean")
+                    max[key] = einops.reduce(batch[key], pattern, "max")
+                    min[key] = einops.reduce(batch[key], pattern, "min")
+                else:
+                    mean[key] += einops.reduce(batch[key], pattern, "mean")
+                    max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+                    min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+            batch = rb.sample()
+
+        for key in self.stats_patterns:
+            mean[key] /= num_batch
+
+        # compute std, min, max
+        for _ in tqdm.tqdm(range(num_batch)):
+            batch = rb.sample()
+            for key, pattern in self.stats_patterns.items():
+                batch[key] = batch[key].float()
+                batch_mean = einops.reduce(batch[key], pattern, "mean")
+                if key not in std:
+                    # first batch initialize std
+                    std[key] = (batch_mean - mean[key]) ** 2
+                else:
+                    std[key] += (batch_mean - mean[key]) ** 2
+                max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+                min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+
+        for key in self.stats_patterns:
+            std[key] = torch.sqrt(std[key] / num_batch)
+
+        stats = TensorDict({}, batch_size=[])
+        for key in self.stats_patterns:
+            stats[(*key, "mean")] = mean[key]
+            stats[(*key, "std")] = std[key]
+            stats[(*key, "max")] = max[key]
+            stats[(*key, "min")] = min[key]
+
+            if key[0] == "observation":
+                # use same stats for the next observations
+                stats[("next", *key)] = stats[key]
         return stats
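One subtlety, present in both the old and the new version: std accumulates (batch_mean - mean) ** 2, so the result is the standard deviation of per-batch means rather than of individual samples; for roughly i.i.d. data it is smaller by about a factor of sqrt(batch_size). A small illustrative sketch, not part of the commit:

    import torch

    x = torch.randn(100_000)                      # per-sample std is ~1.0
    batch_means = x.reshape(-1, 32).mean(dim=1)   # batch_size=32, the default above
    std_of_means = ((batch_means - x.mean()) ** 2).mean().sqrt()
    print(std_of_means)                           # ~0.18, i.e. ~1/sqrt(32), not ~1.0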

File 2 of 7: make_offline_buffer factory

@@ -10,7 +10,9 @@ from lerobot.common.envs.transforms import NormalizeTransform
 DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
 
 
-def make_offline_buffer(cfg, sampler=None):
+def make_offline_buffer(
+    cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
+):
     if cfg.policy.balanced_sampling:
         assert cfg.online_steps > 0
         batch_size = None
@@ -23,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None):
     pin_memory = cfg.device == "cuda"
     prefetch = cfg.prefetch
 
-    overwrite_sampler = sampler is not None
+    if overwrite_batch_size is not None:
+        batch_size = overwrite_batch_size
 
-    if not overwrite_sampler:
+    if overwrite_prefetch is not None:
+        prefetch = overwrite_prefetch
+
+    if overwrite_sampler is None:
         # TODO(rcadene): move batch_size outside
         num_traj_per_batch = cfg.policy.batch_size  # // cfg.horizon
         # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
@@ -46,6 +52,8 @@
             num_slices=num_traj_per_batch,
             strict_length=False,
         )
+    else:
+        sampler = overwrite_sampler
 
     if cfg.env.name == "simxarm":
         from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
@@ -70,30 +78,31 @@
         prefetch=prefetch if isinstance(prefetch, int) else None,
     )
 
-    # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
-    stats = offline_buffer.compute_or_load_stats()
-    in_keys = [("observation", "state"), ("action")]
-
-    if cfg.policy == "tdmpc":
-        # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
-        in_keys.append(("observation", "image"))
-        # since we use next observations in tdmpc
-        in_keys.append(("next", "observation", "image"))
-        in_keys.append(("next", "observation", "state"))
-
-    if cfg.policy == "diffusion" and cfg.env.name == "pusht":
-        # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
-        stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
-        stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
-        stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
-        stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
-
-    transform = NormalizeTransform(stats, in_keys, mode="min_max")
-    offline_buffer.set_transform(transform)
+    if normalize:
+        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
+        stats = offline_buffer.compute_or_load_stats()
+        in_keys = [("observation", "state"), ("action")]
+
+        if cfg.policy == "tdmpc":
+            for key in offline_buffer.image_keys:
+                # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
+                in_keys.append(key)
+                # since we use next observations in tdmpc
+                in_keys.append(("next", *key))
+            in_keys.append(("next", "observation", "state"))
+
+        if cfg.policy == "diffusion" and cfg.env.name == "pusht":
+            # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
+            stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
+            stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
+            stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+            stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
+
+        transform = NormalizeTransform(stats, in_keys, mode="min_max")
+        offline_buffer.set_transform(transform)
 
     if not overwrite_sampler:
-        num_steps = len(offline_buffer)
-        index = torch.arange(0, num_steps, 1)
+        index = torch.arange(0, offline_buffer.num_frames, 1)
         sampler.extend(index)
 
     return offline_buffer
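For reference, both calling styles of the reworked factory appear in this commit: eval.py and train.py keep the defaults, while visualize_dataset (file 7 below) opts out of normalization and overrides sampler, batch size, and prefetch:

    # defaults: normalization on, sampler built from cfg (eval.py, train.py)
    offline_buffer = make_offline_buffer(cfg)

    # overrides: raw unnormalized frames, sequential sampler, batch of 1 (visualize_dataset.py)
    offline_buffer = make_offline_buffer(
        cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
    )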

File 3 of 7: PushtExperienceReplay

@@ -183,8 +183,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
             # last step of demonstration is considered done
             done[-1] = True
 
-            print("before " + """episode = TensorDict(""")
-            episode = TensorDict(
+            ep_td = TensorDict(
                 {
                     ("observation", "image"): image[:-1],
                     ("observation", "state"): agent_pos[:-1],
@@ -203,11 +202,11 @@ class PushtExperienceReplay(AbstractExperienceReplay):
             if episode_id == 0:
                 # hack to initialize tensordict data structure to store episodes
-                td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
+                td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
 
-            td_data[idxtd : idxtd + len(episode)] = episode
+            td_data[idxtd : idxtd + len(ep_td)] = ep_td
 
             idx0 = idx1
-            idxtd = idxtd + len(episode)
+            idxtd = idxtd + len(ep_td)
 
         return TensorStorage(td_data.lock_())
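The episode_id == 0 branch relies on a tensordict idiom worth spelling out: take one frame, expand it to the full dataset length, and memmap_like it to allocate disk-backed storage that later episodes are written into slice by slice. A standalone sketch with made-up shapes and a hypothetical path, not from the commit:

    import torch
    from tensordict import TensorDict

    total_frames = 100
    first_frame = TensorDict({"state": torch.zeros(2), "action": torch.zeros(2)}, batch_size=[])

    # allocate a memory-mapped TensorDict with room for every frame
    td_data = first_frame.expand(total_frames).memmap_like("/tmp/demo_buffer")

    # copy a 10-frame episode into its slice
    ep_td = TensorDict({"state": torch.randn(10, 2), "action": torch.randn(10, 2)}, batch_size=[10])
    td_data[0:10] = ep_td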

File 4 of 7: logger utilities

@@ -6,6 +6,10 @@ from omegaconf import OmegaConf
 from termcolor import colored
 
 
+def log_output_dir(out_dir):
+    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+
+
 def cfg_to_group(cfg, return_list=False):
     """Return a wandb-safe group name for logging. Optionally returns group name as list."""
     # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]

File 5 of 7: eval script

@@ -9,13 +9,13 @@
 import numpy as np
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule
-from termcolor import colored
 from torchrl.envs import EnvBase
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
+from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import set_seed
+from lerobot.common.utils import init_logging, set_seed
 
 
 def write_video(video_path, stacked_frames, fps):
@@ -109,10 +109,18 @@ def eval(cfg: dict, out_dir=None):
     if out_dir is None:
         raise NotImplementedError()
 
-    assert torch.cuda.is_available()
+    init_logging()
+
+    if cfg.device == "cuda":
+        assert torch.cuda.is_available()
+    else:
+        logging.warning("Using CPU, this will be slow.")
+
     torch.backends.cudnn.benchmark = True
+    torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
-    print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
+
+    log_output_dir(out_dir)
 
     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)
@@ -142,6 +150,8 @@ def eval(cfg: dict, out_dir=None):
     )
     print(metrics)
 
+    logging.info("End of eval")
+
 
 if __name__ == "__main__":
     eval_cli()

File 6 of 7: train script

@@ -4,13 +4,12 @@
 import hydra
 import numpy as np
 import torch
 from tensordict.nn import TensorDictModule
-from termcolor import colored
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
-from lerobot.common.logger import Logger
+from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import format_big_number, init_logging, set_seed
 from lerobot.scripts.eval import eval_policy
@@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
 
-    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+    log_output_dir(out_dir)
     logging.info(f"{cfg.env.task=}")
     logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
     logging.info(f"{cfg.online_steps=}")
@@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     for env_step in range(cfg.online_steps):
         if env_step == 0:
             logging.info("Start online training by interacting with environment")
-        # TODO: use SyncDataCollector for that?
         # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
             rollout = env.rollout(
@@ -268,6 +266,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
         step += 1
         online_step += 1
 
+    logging.info("End of training")
+
 
 if __name__ == "__main__":
     train_cli()

File 7 of 7: visualize_dataset script

@@ -1,13 +1,20 @@
+import logging
+import threading
 from pathlib import Path
 
+import einops
 import hydra
 import imageio
 import torch
-from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
+from torchrl.data.replay_buffers import (
+    SamplerWithoutReplacement,
+)
 
 from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.logger import log_output_dir
+from lerobot.common.utils import init_logging
 
-NUM_EPISODES_TO_RENDER = 10
+NUM_EPISODES_TO_RENDER = 50
 MAX_NUM_STEPS = 1000
 FIRST_FRAME = 0
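The sampler swap matters for this script: the old SliceSamplerWithoutReplacement returned whole episode slices, whereas SamplerWithoutReplacement with shuffle=False walks the storage strictly in insertion order, one frame at a time, which is what the episode-boundary detection in the next hunk relies on. A minimal sketch of that behavior, illustrative only, using a plain tensor buffer:

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.data.replay_buffers import SamplerWithoutReplacement

    rb = ReplayBuffer(
        storage=LazyTensorStorage(10),
        sampler=SamplerWithoutReplacement(shuffle=False),
        batch_size=1,
    )
    rb.extend(torch.arange(5))
    print([rb.sample().item() for _ in range(3)])  # [0, 1, 2]: storage order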
@@ -17,45 +24,88 @@ def visualize_dataset_cli(cfg: dict):
     visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
 
 
+def cat_and_write_video(video_path, frames, fps):
+    frames = torch.cat(frames)
+    assert frames.dtype == torch.uint8
+    frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
+    imageio.mimsave(video_path, frames, fps=fps)
+
+
 def visualize_dataset(cfg: dict, out_dir=None):
     if out_dir is None:
         raise NotImplementedError()
 
-    sampler = SliceSamplerWithoutReplacement(
-        num_slices=1,
-        strict_length=False,
+    init_logging()
+    log_output_dir(out_dir)
+
+    # we expect frames of each episode to be stored next to each others sequentially
+    sampler = SamplerWithoutReplacement(
         shuffle=False,
     )
 
-    offline_buffer = make_offline_buffer(cfg, sampler)
+    logging.info("make_offline_buffer")
+    offline_buffer = make_offline_buffer(
+        cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
+    )
 
-    for _ in range(NUM_EPISODES_TO_RENDER):
-        episode = offline_buffer.sample(MAX_NUM_STEPS)
-
-        ep_idx = episode["episode"][FIRST_FRAME].item()
-        ep_frames = torch.cat(
-            [
-                episode["observation"]["image"][FIRST_FRAME][None, ...],
-                episode["next", "observation"]["image"],
-            ],
-            dim=0,
-        )
-
-        video_dir = Path(out_dir) / "visualize_dataset"
-        video_dir.mkdir(parents=True, exist_ok=True)
-        # TODO(rcadene): make fps configurable
-        video_path = video_dir / f"episode_{ep_idx}.mp4"
-
-        assert ep_frames.min().item() >= 0
-        assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
-        assert ep_frames.max().item() <= 255
-        ep_frames = ep_frames.type(torch.uint8)
-        imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
-
-        # ran out of episodes
-        if offline_buffer._sampler._sample_list.numel() == 0:
+    logging.info("Start rendering episodes from offline buffer")
+
+    threads = []
+    frames = {}
+    current_ep_idx = 0
+    logging.info(f"Visualizing episode {current_ep_idx}")
+    for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
+        # TODO(rcadene): make it work with bsize > 1
+        ep_td = offline_buffer.sample(1)
+        ep_idx = ep_td["episode"][FIRST_FRAME].item()
+
+        # TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
+        no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
+        new_episode = ep_idx != current_ep_idx
+
+        if new_episode:
+            logging.info(f"Visualizing episode {current_ep_idx}")
+
+        for im_key in offline_buffer.image_keys:
+            if new_episode or no_more_frames:
+                # append last observed frames (the ones after last action taken)
+                frames[im_key].append(ep_td[("next", *im_key)])
+
+                video_dir = Path(out_dir) / "visualize_dataset"
+                video_dir.mkdir(parents=True, exist_ok=True)
+
+                if len(offline_buffer.image_keys) > 1:
+                    camera = im_key[-1]
+                    video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
+                else:
+                    video_path = video_dir / f"episode_{current_ep_idx}.mp4"
+
+                thread = threading.Thread(
+                    target=cat_and_write_video,
+                    args=(str(video_path), frames[im_key], cfg.fps),
+                )
+                thread.start()
+                threads.append(thread)
+
+                current_ep_idx = ep_idx
+
+                # reset list of frames
+                del frames[im_key]
+
+            # append current cameras images to list of frames
+            if im_key not in frames:
+                frames[im_key] = []
+            frames[im_key].append(ep_td[im_key])
+
+        if no_more_frames:
+            logging.info("Ran out of frames")
             break
 
+    for thread in threads:
+        thread.join()
+
+    logging.info("End of visualize_dataset")
+
 
 if __name__ == "__main__":
     visualize_dataset_cli()
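The writer threads let video encoding overlap with sampling the next episode: each finished episode hands its accumulated (b, c, h, w) uint8 chunks to cat_and_write_video on a background thread, and the join loop at the end keeps the process alive until every file is flushed. A standalone sketch of the same pattern with dummy frames and an illustrative output path:

    import threading

    import einops
    import imageio
    import torch

    def cat_and_write_video(video_path, frames, fps):
        frames = torch.cat(frames)  # list of (b, c, h, w) uint8 chunks -> one tensor
        frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
        imageio.mimsave(video_path, frames, fps=fps)

    # two chunks of 5 random uint8 frames, 3x32x32 each
    chunks = [torch.randint(0, 256, (5, 3, 32, 32), dtype=torch.uint8) for _ in range(2)]
    t = threading.Thread(target=cat_and_write_video, args=("/tmp/episode_0.mp4", chunks, 10))
    t.start()
    t.join()  # as in the script, join before exiting so the video is fully written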