Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements

Remi Cadene 2024-03-06 10:14:03 +00:00
parent 2f80d71c3e
commit f95ecd66fc
7 changed files with 195 additions and 150 deletions

View File

@@ -1,6 +1,5 @@
import abc
import logging
import math
from pathlib import Path
from typing import Callable
@@ -50,6 +49,22 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
transform=transform,
)
@property
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action"): "b c -> 1 c",
}
@property
def image_keys(self) -> list:
return [("observation", "image")]
@property
def num_cameras(self) -> int:
return len(self.image_keys)
@property
def num_samples(self) -> int:
return len(self)
@@ -67,7 +82,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
stats = torch.load(stats_path)
else:
logging.info(f"compute_stats and save to {stats_path}")
stats = self._compute_stats(self._storage._storage, num_batch, batch_size)
stats = self._compute_stats(num_batch, batch_size)
torch.save(stats, stats_path)
return stats
@@ -85,101 +100,59 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def _is_downloaded(self) -> bool:
return self.data_dir.is_dir()
def _compute_stats(self, storage, num_batch=100, batch_size=32):
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(
storage=storage,
storage=self._storage,
batch_size=batch_size,
prefetch=True,
)
batch = rb.sample()
image_channels = batch["observation", "image"].shape[1]
image_mean = torch.zeros(image_channels)
image_std = torch.zeros(image_channels)
image_max = torch.tensor([-math.inf] * image_channels)
image_min = torch.tensor([math.inf] * image_channels)
state_channels = batch["observation", "state"].shape[1]
state_mean = torch.zeros(state_channels)
state_std = torch.zeros(state_channels)
state_max = torch.tensor([-math.inf] * state_channels)
state_min = torch.tensor([math.inf] * state_channels)
action_channels = batch["action"].shape[1]
action_mean = torch.zeros(action_channels)
action_std = torch.zeros(action_channels)
action_max = torch.tensor([-math.inf] * action_channels)
action_min = torch.tensor([math.inf] * action_channels)
mean, std, max, min = {}, {}, {}, {}
# compute mean, min, max
for _ in tqdm.tqdm(range(num_batch)):
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
image_max = torch.maximum(image_max, b_image_max)
image_min = torch.minimum(image_min, b_image_min)
state_max = torch.maximum(state_max, b_state_max)
state_min = torch.minimum(state_min, b_state_min)
action_max = torch.maximum(action_max, b_action_max)
action_min = torch.minimum(action_min, b_action_min)
batch = rb.sample()
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
if key not in mean:
# first batch initialize mean, min, max
mean[key] = einops.reduce(batch[key], pattern, "mean")
max[key] = einops.reduce(batch[key], pattern, "max")
min[key] = einops.reduce(batch[key], pattern, "min")
else:
mean[key] += einops.reduce(batch[key], pattern, "mean")
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
batch = rb.sample()
image_mean /= num_batch
state_mean /= num_batch
action_mean /= num_batch
for key in self.stats_patterns:
mean[key] /= num_batch
for i in tqdm.tqdm(range(num_batch)):
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
image_std += (b_image_mean - image_mean) ** 2
state_std += (b_state_mean - state_mean) ** 2
action_std += (b_action_mean - action_mean) ** 2
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
image_max = torch.maximum(image_max, b_image_max)
image_min = torch.minimum(image_min, b_image_min)
state_max = torch.maximum(state_max, b_state_max)
state_min = torch.minimum(state_min, b_state_min)
action_max = torch.maximum(action_max, b_action_max)
action_min = torch.minimum(action_min, b_action_min)
if i < num_batch - 1:
# compute std, min, max
for _ in tqdm.tqdm(range(num_batch)):
batch = rb.sample()
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
batch_mean = einops.reduce(batch[key], pattern, "mean")
if key not in std:
# first batch initialize std
std[key] = (batch_mean - mean[key]) ** 2
else:
std[key] += (batch_mean - mean[key]) ** 2
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
image_std = torch.sqrt(image_std / num_batch)
state_std = torch.sqrt(state_std / num_batch)
action_std = torch.sqrt(action_std / num_batch)
for key in self.stats_patterns:
std[key] = torch.sqrt(std[key] / num_batch)
stats = TensorDict(
{
("observation", "image", "mean"): image_mean[None, :, None, None],
("observation", "image", "std"): image_std[None, :, None, None],
("observation", "image", "max"): image_max[None, :, None, None],
("observation", "image", "min"): image_min[None, :, None, None],
("observation", "state", "mean"): state_mean[None, :],
("observation", "state", "std"): state_std[None, :],
("observation", "state", "max"): state_max[None, :],
("observation", "state", "min"): state_min[None, :],
("action", "mean"): action_mean[None, :],
("action", "std"): action_std[None, :],
("action", "max"): action_max[None, :],
("action", "min"): action_min[None, :],
},
batch_size=[],
)
stats["next", "observation", "image"] = stats["observation", "image"]
stats["next", "observation", "state"] = stats["observation", "state"]
stats = TensorDict({}, batch_size=[])
for key in self.stats_patterns:
stats[(*key, "mean")] = mean[key]
stats[(*key, "std")] = std[key]
stats[(*key, "max")] = max[key]
stats[(*key, "min")] = min[key]
if key[0] == "observation":
# use same stats for the next observations
stats[("next", *key)] = stats[key]
return stats
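For reference, a minimal standalone sketch of the two-pass, pattern-driven statistics loop above. `sample_batch` and the plain-dict batches are illustrative stand-ins for `rb.sample()` and TensorDict batches, not the repo's API:

import einops
import torch

def compute_stats(sample_batch, stats_patterns, num_batch=100):
    mean, std, vmax, vmin = {}, {}, {}, {}
    # pass 1: running sum of per-batch means, plus element-wise min/max
    for _ in range(num_batch):
        batch = sample_batch()
        for key, pattern in stats_patterns.items():
            x = batch[key].float()
            b_mean = einops.reduce(x, pattern, "mean")
            b_max = einops.reduce(x, pattern, "max")
            b_min = einops.reduce(x, pattern, "min")
            if key not in mean:
                mean[key], vmax[key], vmin[key] = b_mean, b_max, b_min
            else:
                mean[key] += b_mean
                vmax[key] = torch.maximum(vmax[key], b_max)
                vmin[key] = torch.minimum(vmin[key], b_min)
    for key in stats_patterns:
        mean[key] /= num_batch
    # pass 2: spread of per-batch means around the global mean
    # (same approximation of the per-sample std as the buffer code above)
    for _ in range(num_batch):
        batch = sample_batch()
        for key, pattern in stats_patterns.items():
            b_mean = einops.reduce(batch[key].float(), pattern, "mean")
            std[key] = std.get(key, 0) + (b_mean - mean[key]) ** 2
    for key in stats_patterns:
        std[key] = torch.sqrt(std[key] / num_batch)
    return mean, std, vmax, vmin

Note that this `std` measures the spread of batch means, which understates the true per-element std; the commit keeps that approximation from the previous version.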

View File

@@ -10,7 +10,9 @@ from lerobot.common.envs.transforms import NormalizeTransform
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
def make_offline_buffer(cfg, sampler=None):
def make_offline_buffer(
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
):
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
@@ -23,9 +25,13 @@ def make_offline_buffer(cfg, sampler=None):
pin_memory = cfg.device == "cuda"
prefetch = cfg.prefetch
overwrite_sampler = sampler is not None
if overwrite_batch_size is not None:
batch_size = overwrite_batch_size
if not overwrite_sampler:
if overwrite_prefetch is not None:
prefetch = overwrite_prefetch
if overwrite_sampler is None:
# TODO(rcadene): move batch_size outside
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
@@ -46,6 +52,8 @@ def make_offline_buffer(cfg, sampler=None):
num_slices=num_traj_per_batch,
strict_length=False,
)
else:
sampler = overwrite_sampler
if cfg.env.name == "simxarm":
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
@@ -70,15 +78,17 @@ def make_offline_buffer(cfg, sampler=None):
prefetch=prefetch if isinstance(prefetch, int) else None,
)
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
in_keys = [("observation", "state"), ("action")]
if cfg.policy == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(("observation", "image"))
in_keys.append(key)
# since we use next observations in tdmpc
in_keys.append(("next", "observation", "image"))
in_keys.append(("next", *key))
in_keys.append(("next", "observation", "state"))
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
@@ -92,8 +102,7 @@ def make_offline_buffer(cfg, sampler=None):
offline_buffer.set_transform(transform)
if not overwrite_sampler:
num_steps = len(offline_buffer)
index = torch.arange(0, num_steps, 1)
index = torch.arange(0, offline_buffer.num_samples, 1)
sampler.extend(index)
return offline_buffer
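For illustration, the new keyword arguments might be combined like this; it mirrors the call made in visualize_dataset further down (`cfg` is the usual Hydra config):

from torchrl.data.replay_buffers import SamplerWithoutReplacement

from lerobot.common.datasets.factory import make_offline_buffer

# sequential frames, raw (unnormalized) values, one frame per sample
sampler = SamplerWithoutReplacement(shuffle=False)
offline_buffer = make_offline_buffer(
    cfg,
    overwrite_sampler=sampler,  # bypasses the default slice sampler
    normalize=False,            # skips the NormalizeTransform
    overwrite_batch_size=1,
    overwrite_prefetch=12,
)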

View File

@@ -183,8 +183,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
# last step of demonstration is considered done
done[-1] = True
print("before " + """episode = TensorDict(""")
episode = TensorDict(
ep_td = TensorDict(
{
("observation", "image"): image[:-1],
("observation", "state"): agent_pos[:-1],
@@ -203,11 +202,11 @@ class PushtExperienceReplay(AbstractExperienceReplay):
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
td_data[idxtd : idxtd + len(episode)] = episode
td_data[idxtd : idxtd + len(ep_td)] = ep_td
idx0 = idx1
idxtd = idxtd + len(episode)
idxtd = idxtd + len(ep_td)
return TensorStorage(td_data.lock_())
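The episode-writing code above uses tensordict's memmap preallocation idiom: take one frame of the first episode, expand it to the full dataset length, memory-map the result to disk, then fill each episode's slice. A minimal sketch, with illustrative field names, shapes, and path:

import torch
from tensordict import TensorDict

total_frames = 1000
ep_td = TensorDict({("observation", "state"): torch.randn(50, 2)}, batch_size=[50])
# one frame, expanded to the target length and backed by memory-mapped files
td_data = ep_td[0].expand(total_frames).memmap_like("/tmp/demo_buffer")
td_data[0:50] = ep_td  # copy the first episode into its slice

Expanding a single frame costs no memory (it is a view); memmap_like then allocates the on-disk buffers that the per-episode slice assignments write into.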

View File

@@ -6,6 +6,10 @@ from omegaconf import OmegaConf
from termcolor import colored
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg, return_list=False):
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]

View File

@@ -9,13 +9,13 @@ import numpy as np
import torch
import tqdm
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.envs import EnvBase
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import set_seed
from lerobot.common.utils import init_logging, set_seed
def write_video(video_path, stacked_frames, fps):
@@ -109,10 +109,18 @@ def eval(cfg: dict, out_dir=None):
if out_dir is None:
raise NotImplementedError()
init_logging()
if cfg.device == "cuda":
assert torch.cuda.is_available()
else:
logging.warning("Using CPU, this will be slow.")
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
log_output_dir(out_dir)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
@@ -142,6 +150,8 @@
)
print(metrics)
logging.info("End of eval")
if __name__ == "__main__":
eval_cli()

View File

@@ -4,13 +4,12 @@ import hydra
import numpy as np
import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import format_big_number, init_logging, set_seed
from lerobot.scripts.eval import eval_policy
@@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}")
@@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
for env_step in range(cfg.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
# TODO: use SyncDataCollector for that?
# TODO: add configurable number of rollout? (default=1)
with torch.no_grad():
rollout = env.rollout(
@@ -268,6 +266,8 @@
step += 1
online_step += 1
logging.info("End of training")
if __name__ == "__main__":
train_cli()

View File

@@ -1,13 +1,20 @@
import logging
import threading
from pathlib import Path
import einops
import hydra
import imageio
import torch
from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.logger import log_output_dir
from lerobot.common.utils import init_logging
NUM_EPISODES_TO_RENDER = 10
NUM_EPISODES_TO_RENDER = 50
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
@@ -17,45 +24,88 @@ def visualize_dataset_cli(cfg: dict):
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
def cat_and_write_video(video_path, frames, fps):
frames = torch.cat(frames)
assert frames.dtype == torch.uint8
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
imageio.mimsave(video_path, frames, fps=fps)
def visualize_dataset(cfg: dict, out_dir=None):
if out_dir is None:
raise NotImplementedError()
sampler = SliceSamplerWithoutReplacement(
num_slices=1,
strict_length=False,
init_logging()
log_output_dir(out_dir)
# we expect frames of each episode to be stored next to each other sequentially
sampler = SamplerWithoutReplacement(
shuffle=False,
)
offline_buffer = make_offline_buffer(cfg, sampler)
for _ in range(NUM_EPISODES_TO_RENDER):
episode = offline_buffer.sample(MAX_NUM_STEPS)
ep_idx = episode["episode"][FIRST_FRAME].item()
ep_frames = torch.cat(
[
episode["observation"]["image"][FIRST_FRAME][None, ...],
episode["next", "observation"]["image"],
],
dim=0,
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
)
logging.info("Start rendering episodes from offline buffer")
threads = []
frames = {}
current_ep_idx = 0
logging.info(f"Visualizing episode {current_ep_idx}")
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
# TODO(rcadene): make it work with bsize > 1
ep_td = offline_buffer.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample its frames
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
new_episode = ep_idx != current_ep_idx
if new_episode:
logging.info(f"Visualizing episode {current_ep_idx}")
for im_key in offline_buffer.image_keys:
if new_episode or no_more_frames:
# append last observed frames (the ones after last action taken)
frames[im_key].append(ep_td[("next", *im_key)])
video_dir = Path(out_dir) / "visualize_dataset"
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"episode_{ep_idx}.mp4"
assert ep_frames.min().item() >= 0
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
if len(offline_buffer.image_keys) > 1:
camera = im_key[-1]
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else:
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
# ran out of episodes
if offline_buffer._sampler._sample_list.numel() == 0:
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], cfg.fps),
)
thread.start()
threads.append(thread)
current_ep_idx = ep_idx
# reset list of frames
del frames[im_key]
# append current cameras images to list of frames
if im_key not in frames:
frames[im_key] = []
frames[im_key].append(ep_td[im_key])
if no_more_frames:
logging.info("Ran out of frames")
break
for thread in threads:
thread.join()
logging.info("End of visualize_dataset")
if __name__ == "__main__":
visualize_dataset_cli()
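The rendering loop above interleaves sequential sampling with background video writing. Stripped of the buffer and camera details, the grouping-and-flush pattern is roughly the following sketch, where `stream` yields (episode_id, frame) pairs and `write_video` is a stand-in for cat_and_write_video:

import threading

def render_stream(stream, write_video):
    # group frames by episode; flush each finished episode to a writer thread
    threads, frames, current_ep = [], [], None
    for ep_idx, frame in stream:
        if current_ep is not None and ep_idx != current_ep:
            t = threading.Thread(target=write_video, args=(current_ep, frames))
            t.start()
            threads.append(t)
            frames = []
        current_ep = ep_idx
        frames.append(frame)
    if frames:
        # flush the last episode once the stream runs out
        t = threading.Thread(target=write_video, args=(current_ep, frames))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()

Writing in threads lets imageio encode one episode while the next is being sampled; the final join ensures all files are written before the script exits.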