fix train.py, stats, eval.py (training is running)

This commit is contained in:
Cadene 2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

View File

@ -91,15 +91,17 @@ class AlohaDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property
def num_samples(self) -> int:

View File

@ -105,15 +105,17 @@ class PushtDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property
def num_samples(self) -> int:

View File

@ -46,15 +46,17 @@ class SimxarmDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property
def num_samples(self) -> int:

View File

@ -112,16 +112,19 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
shuffle=False,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
# these einops patterns will be used to aggregate batches and compute statistics
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
@ -142,9 +145,9 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
first_batch = None
running_item_count = 0 # for online mean computation
for i, batch in enumerate(
tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = batch.batch_size[0]
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
@ -166,8 +169,10 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None
running_item_count = 0 # for online std computation
for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")):
this_batch_size = batch.batch_size[0]
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:

View File

@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
result = {"action": action, "action_pred": action_pred}
return result
def compute_loss(self, batch):
assert "valid_mask" not in batch
nobs = batch["obs"]
nactions = batch["action"]
def compute_loss(self, obs_dict, action):
nobs = obs_dict
nactions = action
batch_size = nactions.shape[0]
horizon = nactions.shape[1]

View File

@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module):
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
loss = self.diffusion.compute_loss(obs_dict)
action = batch["action"]
loss = self.diffusion.compute_loss(obs_dict, action)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(

View File

@ -72,12 +72,12 @@ class NormalizeTransform(Transform):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[f"{inkey}.mean"]
std = self.stats[f"{inkey}.std"]
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[f"{inkey}.min"]
max = self.stats[f"{inkey}.max"]
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
@ -89,12 +89,12 @@ class NormalizeTransform(Transform):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[f"{inkey}.mean"]
std = self.stats[f"{inkey}.std"]
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = item[inkey] * std + mean
else:
min = self.stats[f"{inkey}.min"]
max = self.stats[f"{inkey}.max"]
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item

View File

@ -37,7 +37,6 @@ from pathlib import Path
import einops
import gymnasium as gym
import hydra
import imageio
import numpy as np
import torch
@ -47,8 +46,8 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.transforms import apply_inverse_transform
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
def write_video(video_path, stacked_frames, fps):
@ -92,9 +91,12 @@ def eval_policy(
fps: int = 15,
return_first_video: bool = False,
transform: callable = None,
seed=None,
):
if policy is not None:
policy.eval()
device = "cpu" if policy is None else next(policy.parameters()).device
start = time.time()
sum_rewards = []
max_rewards = []
@ -125,11 +127,11 @@ def eval_policy(
policy.reset()
else:
logging.warning(
f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
)
# reset the environment
observation, info = env.reset(seed=cfg.seed)
observation, info = env.reset(seed=seed)
maybe_render_frame(env)
rewards = []
@ -138,13 +140,12 @@ def eval_policy(
done = torch.tensor([False for _ in env.envs])
step = 0
do_rollout = True
while do_rollout:
while not done.all():
# apply transform to normalize the observations
observation = preprocess_observation(observation, transform)
# send observation to device/gpu
observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
# get the next action for the environment
with torch.inference_mode():
@ -180,10 +181,6 @@ def eval_policy(
step += 1
if done.all():
do_rollout = False
break
rewards = torch.stack(rewards, dim=1)
successes = torch.stack(successes, dim=1)
dones = torch.stack(dones, dim=1)
@ -295,6 +292,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
fps=cfg.env.fps,
# TODO(rcadene): what should we do with the transform?
transform=dataset.transform,
seed=cfg.seed,
)
print(info["aggregated"])

View File

@ -145,7 +145,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# )
logging.info("make_env")
env = make_env(cfg)
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
logging.info("make_policy")
policy = make_policy(cfg)
@ -173,12 +173,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
eval_info, first_video = eval_policy(
env,
policy,
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,
transform=dataset.transform,
seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
if cfg.wandb.enable:
@ -211,7 +210,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy.update(batch, step)
train_info = policy(batch, step)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
@ -223,6 +222,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1
raise NotImplementedError()
demo_buffer = dataset if cfg.policy.balanced_sampling else None
online_step = 0
is_offline = False

View File

@ -18,28 +18,26 @@ Example:
import argparse
import shutil
from tensordict import TensorDict
from pathlib import Path
import torch
def mock_dataset(in_data_dir, out_data_dir, num_frames):
in_data_dir = Path(in_data_dir)
out_data_dir = Path(out_data_dir)
# load full dataset as a tensor dict
in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer")
# copy the first `n` frames for each data key so that we have real data
in_data_dict = torch.load(in_data_dir / "data_dict.pth")
out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
torch.save(out_data_dict, out_data_dir / "data_dict.pth")
# use 1 frame to know the specification of the dataset
# and copy it over `n` frames in the test artifact directory
out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer")
# copy the full mapping between data_id and episode since it's small
in_ids_per_ep_path = in_data_dir / "data_ids_per_episode.pth"
out_ids_per_ep_path = out_data_dir / "data_ids_per_episode.pth"
shutil.copy(in_ids_per_ep_path, out_ids_per_ep_path)
# copy the first `n` frames so that we have real data
out_td_data[:num_frames] = in_td_data[:num_frames].clone()
# make sure everything has been properly written
out_td_data.lock_()
# copy the full statistics of dataset since it's pretty small
# copy the full statistics of dataset since it's small
in_stats_path = in_data_dir / "stats.pth"
out_stats_path = out_data_dir / "stats.pth"
shutil.copy(in_stats_path, out_stats_path)

View File

@ -59,11 +59,7 @@ def test_factory(env_name, dataset_id):
# )
# dataset = make_dataset(cfg)
# # Get all of the data.
# all_data = TensorDictReplayBuffer(
# storage=buffer._storage,
# batch_size=len(buffer),
# sampler=SamplerWithoutReplacement(),
# ).sample().float()
# all_data = dataset.data_dict
# # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# # computation of the statistics. While doing this, we also make sure it works when we don't divide the
# # dataset into even batches.