fix train.py, stats, eval.py (training is running)

This commit is contained in:
Cadene 2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

View File

@ -91,15 +91,17 @@ class AlohaDataset(torch.utils.data.Dataset):
self.transform = transform self.transform = transform
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}" self.data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): if (self.data_dir / "data_dict.pth").exists() and (
self.data_dict = torch.load(data_dir / "data_dict.pth") self.data_dir / "data_ids_per_episode.pth"
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") ).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else: else:
self._download_and_preproc_obsolete() self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True) self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth") torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:

View File

@ -105,15 +105,17 @@ class PushtDataset(torch.utils.data.Dataset):
self.transform = transform self.transform = transform
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}" self.data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): if (self.data_dir / "data_dict.pth").exists() and (
self.data_dict = torch.load(data_dir / "data_dict.pth") self.data_dir / "data_ids_per_episode.pth"
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") ).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else: else:
self._download_and_preproc_obsolete() self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True) self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth") torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:

View File

@ -46,15 +46,17 @@ class SimxarmDataset(torch.utils.data.Dataset):
self.transform = transform self.transform = transform
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}" self.data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists(): if (self.data_dir / "data_dict.pth").exists() and (
self.data_dict = torch.load(data_dir / "data_dict.pth") self.data_dir / "data_ids_per_episode.pth"
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth") ).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else: else:
self._download_and_preproc_obsolete() self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True) self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth") torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth") torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:

View File

@ -112,16 +112,19 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None: if max_num_samples is None:
max_num_samples = len(dataset) max_num_samples = len(dataset)
else:
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=4, num_workers=4,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=False,
# pin_memory=cfg.device != "cpu", # pin_memory=cfg.device != "cpu",
drop_last=False, drop_last=False,
) )
# these einops patterns will be used to aggregate batches and compute statistics
stats_patterns = { stats_patterns = {
"action": "b c -> c", "action": "b c -> c",
"observation.state": "b c -> c", "observation.state": "b c -> c",
@ -142,9 +145,9 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
first_batch = None first_batch = None
running_item_count = 0 # for online mean computation running_item_count = 0 # for online mean computation
for i, batch in enumerate( for i, batch in enumerate(
tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
): ):
this_batch_size = batch.batch_size[0] this_batch_size = len(batch["index"])
running_item_count += this_batch_size running_item_count += this_batch_size
if first_batch is None: if first_batch is None:
first_batch = deepcopy(batch) first_batch = deepcopy(batch)
@ -166,8 +169,10 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None first_batch_ = None
running_item_count = 0 # for online std computation running_item_count = 0 # for online std computation
for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")): for i, batch in enumerate(
this_batch_size = batch.batch_size[0] tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before. # Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None: if first_batch_ is None:

View File

@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
result = {"action": action, "action_pred": action_pred} result = {"action": action, "action_pred": action_pred}
return result return result
def compute_loss(self, batch): def compute_loss(self, obs_dict, action):
assert "valid_mask" not in batch nobs = obs_dict
nobs = batch["obs"] nactions = action
nactions = batch["action"]
batch_size = nactions.shape[0] batch_size = nactions.shape[0]
horizon = nactions.shape[1] horizon = nactions.shape[1]

View File

@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module):
"image": batch["observation.image"], "image": batch["observation.image"],
"agent_pos": batch["observation.state"], "agent_pos": batch["observation.state"],
} }
loss = self.diffusion.compute_loss(obs_dict) action = batch["action"]
loss = self.diffusion.compute_loss(obs_dict, action)
loss.backward() loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(

View File

@ -72,12 +72,12 @@ class NormalizeTransform(Transform):
if inkey not in item: if inkey not in item:
continue continue
if self.mode == "mean_std": if self.mode == "mean_std":
mean = self.stats[f"{inkey}.mean"] mean = self.stats[inkey]["mean"]
std = self.stats[f"{inkey}.std"] std = self.stats[inkey]["std"]
item[outkey] = (item[inkey] - mean) / (std + 1e-8) item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else: else:
min = self.stats[f"{inkey}.min"] min = self.stats[inkey]["min"]
max = self.stats[f"{inkey}.max"] max = self.stats[inkey]["max"]
# normalize to [0,1] # normalize to [0,1]
item[outkey] = (item[inkey] - min) / (max - min) item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1] # normalize to [-1, 1]
@ -89,12 +89,12 @@ class NormalizeTransform(Transform):
if inkey not in item: if inkey not in item:
continue continue
if self.mode == "mean_std": if self.mode == "mean_std":
mean = self.stats[f"{inkey}.mean"] mean = self.stats[inkey]["mean"]
std = self.stats[f"{inkey}.std"] std = self.stats[inkey]["std"]
item[outkey] = item[inkey] * std + mean item[outkey] = item[inkey] * std + mean
else: else:
min = self.stats[f"{inkey}.min"] min = self.stats[inkey]["min"]
max = self.stats[f"{inkey}.max"] max = self.stats[inkey]["max"]
item[outkey] = (item[inkey] + 1) / 2 item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min item[outkey] = item[outkey] * (max - min) + min
return item return item

View File

@ -37,7 +37,6 @@ from pathlib import Path
import einops import einops
import gymnasium as gym import gymnasium as gym
import hydra
import imageio import imageio
import numpy as np import numpy as np
import torch import torch
@ -47,8 +46,8 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.transforms import apply_inverse_transform from lerobot.common.transforms import apply_inverse_transform
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
def write_video(video_path, stacked_frames, fps): def write_video(video_path, stacked_frames, fps):
@ -92,9 +91,12 @@ def eval_policy(
fps: int = 15, fps: int = 15,
return_first_video: bool = False, return_first_video: bool = False,
transform: callable = None, transform: callable = None,
seed=None,
): ):
if policy is not None: if policy is not None:
policy.eval() policy.eval()
device = "cpu" if policy is None else next(policy.parameters()).device
start = time.time() start = time.time()
sum_rewards = [] sum_rewards = []
max_rewards = [] max_rewards = []
@ -125,11 +127,11 @@ def eval_policy(
policy.reset() policy.reset()
else: else:
logging.warning( logging.warning(
f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout." f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
) )
# reset the environment # reset the environment
observation, info = env.reset(seed=cfg.seed) observation, info = env.reset(seed=seed)
maybe_render_frame(env) maybe_render_frame(env)
rewards = [] rewards = []
@ -138,13 +140,12 @@ def eval_policy(
done = torch.tensor([False for _ in env.envs]) done = torch.tensor([False for _ in env.envs])
step = 0 step = 0
do_rollout = True while not done.all():
while do_rollout:
# apply transform to normalize the observations # apply transform to normalize the observations
observation = preprocess_observation(observation, transform) observation = preprocess_observation(observation, transform)
# send observation to device/gpu # send observation to device/gpu
observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation} observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
# get the next action for the environment # get the next action for the environment
with torch.inference_mode(): with torch.inference_mode():
@ -180,10 +181,6 @@ def eval_policy(
step += 1 step += 1
if done.all():
do_rollout = False
break
rewards = torch.stack(rewards, dim=1) rewards = torch.stack(rewards, dim=1)
successes = torch.stack(successes, dim=1) successes = torch.stack(successes, dim=1)
dones = torch.stack(dones, dim=1) dones = torch.stack(dones, dim=1)
@ -295,6 +292,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
fps=cfg.env.fps, fps=cfg.env.fps,
# TODO(rcadene): what should we do with the transform? # TODO(rcadene): what should we do with the transform?
transform=dataset.transform, transform=dataset.transform,
seed=cfg.seed,
) )
print(info["aggregated"]) print(info["aggregated"])

View File

@ -145,7 +145,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# ) # )
logging.info("make_env") logging.info("make_env")
env = make_env(cfg) env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
logging.info("make_policy") logging.info("make_policy")
policy = make_policy(cfg) policy = make_policy(cfg)
@ -173,12 +173,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
eval_info, first_video = eval_policy( eval_info, first_video = eval_policy(
env, env,
policy, policy,
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length,
return_first_video=True, return_first_video=True,
video_dir=Path(out_dir) / "eval", video_dir=Path(out_dir) / "eval",
save_video=True, save_video=True,
transform=dataset.transform, transform=dataset.transform,
seed=cfg.seed,
) )
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline) log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
if cfg.wandb.enable: if cfg.wandb.enable:
@ -211,7 +210,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch: for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True) batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy.update(batch, step) train_info = policy(batch, step)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0: if step % cfg.log_freq == 0:
@ -223,6 +222,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1 step += 1
raise NotImplementedError()
demo_buffer = dataset if cfg.policy.balanced_sampling else None demo_buffer = dataset if cfg.policy.balanced_sampling else None
online_step = 0 online_step = 0
is_offline = False is_offline = False

View File

@ -18,28 +18,26 @@ Example:
import argparse import argparse
import shutil import shutil
from tensordict import TensorDict
from pathlib import Path from pathlib import Path
import torch
def mock_dataset(in_data_dir, out_data_dir, num_frames): def mock_dataset(in_data_dir, out_data_dir, num_frames):
in_data_dir = Path(in_data_dir) in_data_dir = Path(in_data_dir)
out_data_dir = Path(out_data_dir) out_data_dir = Path(out_data_dir)
# load full dataset as a tensor dict # copy the first `n` frames for each data key so that we have real data
in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer") in_data_dict = torch.load(in_data_dir / "data_dict.pth")
out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
torch.save(out_data_dict, out_data_dir / "data_dict.pth")
# use 1 frame to know the specification of the dataset # copy the full mapping between data_id and episode since it's small
# and copy it over `n` frames in the test artifact directory in_ids_per_ep_path = in_data_dir / "data_ids_per_episode.pth"
out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer") out_ids_per_ep_path = out_data_dir / "data_ids_per_episode.pth"
shutil.copy(in_ids_per_ep_path, out_ids_per_ep_path)
# copy the first `n` frames so that we have real data # copy the full statistics of dataset since it's small
out_td_data[:num_frames] = in_td_data[:num_frames].clone()
# make sure everything has been properly written
out_td_data.lock_()
# copy the full statistics of dataset since it's pretty small
in_stats_path = in_data_dir / "stats.pth" in_stats_path = in_data_dir / "stats.pth"
out_stats_path = out_data_dir / "stats.pth" out_stats_path = out_data_dir / "stats.pth"
shutil.copy(in_stats_path, out_stats_path) shutil.copy(in_stats_path, out_stats_path)

View File

@ -59,11 +59,7 @@ def test_factory(env_name, dataset_id):
# ) # )
# dataset = make_dataset(cfg) # dataset = make_dataset(cfg)
# # Get all of the data. # # Get all of the data.
# all_data = TensorDictReplayBuffer( # all_data = dataset.data_dict
# storage=buffer._storage,
# batch_size=len(buffer),
# sampler=SamplerWithoutReplacement(),
# ).sample().float()
# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched # # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# # computation of the statistics. While doing this, we also make sure it works when we don't divide the # # computation of the statistics. While doing this, we also make sure it works when we don't divide the
# # dataset into even batches. # # dataset into even batches.