test_examples are passing

This commit is contained in:
Cadene 2024-04-10 13:45:45 +00:00
parent 6082a7bc73
commit c08003278e
4 changed files with 62 additions and 79 deletions

View File

@ -1,6 +1,5 @@
import os
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from pathlib import Path
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
@ -9,16 +8,13 @@ from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# we use this sampler to sample 1 frame after the other
sampler = SamplerWithoutReplacement(shuffle=False)
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
# TODO(rcadene): remove DATA_DIR
dataset = AlohaDataset("aloha_sim_transfer_cube_human", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_samples=300,
fps=50,
max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']

View File

@ -9,9 +9,8 @@ from pathlib import Path
import torch
from omegaconf import OmegaConf
from tqdm import trange
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
@ -37,19 +36,33 @@ policy = DiffusionPolicy(
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps,
**cfg.policy,
)
policy.train()
offline_buffer = make_offline_buffer(cfg)
dataset = make_dataset(cfg)
# create dataloader for offline training
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",
drop_last=True,
)
for step, batch in enumerate(dataloader):
info = policy(batch, step)
if step % cfg.log_freq == 0:
num_samples = (step + 1) * cfg.policy.batch_size
loss = info["loss"]
update_s = info["update_s"]
print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
for offline_step in trange(cfg.offline_steps):
train_info = policy.update(offline_buffer, offline_step)
if offline_step % cfg.log_freq == 0:
print(train_info)
# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

View File

@ -40,7 +40,8 @@ def make_dataset(
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
@ -51,7 +52,7 @@ def make_dataset(
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
else:
elif stats_path is None:
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
@ -59,9 +60,8 @@ def make_dataset(
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[

View File

@ -6,9 +6,6 @@ import einops
import hydra
import imageio
import torch
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None):
init_logging()
log_output_dir(out_dir)
# we expect frames of each episode to be stored next to each others sequentially
sampler = SamplerWithoutReplacement(
shuffle=False,
)
logging.info("make_dataset")
dataset = make_dataset(
cfg,
overwrite_sampler=sampler,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
normalize=False,
overwrite_batch_size=1,
overwrite_prefetch=12,
)
logging.info("Start rendering episodes from offline buffer")
@ -60,64 +49,49 @@ def visualize_dataset(cfg: dict, out_dir=None):
logging.info(video_path)
def render_dataset(dataset, out_dir, max_num_samples, fps):
def render_dataset(dataset, out_dir, max_num_episodes):
out_dir = Path(out_dir)
video_paths = []
threads = []
frames = {}
current_ep_idx = 0
logging.info(f"Visualizing episode {current_ep_idx}")
for i in range(max_num_samples):
# TODO(rcadene): make it work with bsize > 1
ep_td = dataset.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
num_frames_left = dataset._sampler._sample_list.numel()
episode_is_done = ep_idx != current_ep_idx
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=1,
shuffle=False,
)
dl_iter = iter(dataloader)
if episode_is_done:
logging.info(f"Rendering episode {current_ep_idx}")
num_episodes = len(dataset.data_ids_per_episode)
for ep_id in range(min(max_num_episodes, num_episodes)):
logging.info(f"Rendering episode {ep_id}")
for im_key in dataset.image_keys:
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
frames = {}
for _ in dataset.data_ids_per_episode[ep_id]:
item = next(dl_iter)
for im_key in dataset.image_keys:
# when first frame of episode, initialize frames dict
if im_key not in frames:
frames[im_key] = []
# add current frame to list of frames to render
frames[im_key].append(ep_td[im_key])
frames[im_key].append(item[im_key])
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:
if len(dataset.image_keys) > 0:
im_name = im_key.replace("observation.images.", "")
video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
else:
# When episode has no more frame in its list of observation,
# one frame still remains. It is the result of the last action taken.
# It is stored in `"next"`, so we add it to the list of frames to render.
frames[im_key].append(ep_td["next"][im_key])
video_path = out_dir / f"episode_{ep_id}.mp4"
video_paths.append(video_path)
out_dir.mkdir(parents=True, exist_ok=True)
if len(dataset.image_keys) > 1:
camera = im_key[-1]
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else:
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], fps),
)
thread.start()
threads.append(thread)
current_ep_idx = ep_idx
# reset list of frames
del frames[im_key]
if num_frames_left == 0:
logging.info("Ran out of frames")
break
if current_ep_idx == NUM_EPISODES_TO_RENDER:
break
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], dataset.fps),
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()