fix train.py, stats, eval.py (training is running)
This commit is contained in:
parent c93ce35d8c
commit 5af00d0c1e
@@ -91,15 +91,17 @@ class AlohaDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
 
-        data_dir = self.root / f"{self.dataset_id}"
-        if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
-            self.data_dict = torch.load(data_dir / "data_dict.pth")
-            self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
+        self.data_dir = self.root / f"{self.dataset_id}"
+        if (self.data_dir / "data_dict.pth").exists() and (
+            self.data_dir / "data_ids_per_episode.pth"
+        ).exists():
+            self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+            self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
         else:
             self._download_and_preproc_obsolete()
-            data_dir.mkdir(parents=True, exist_ok=True)
-            torch.save(self.data_dict, data_dir / "data_dict.pth")
-            torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
+            self.data_dir.mkdir(parents=True, exist_ok=True)
+            torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+            torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
 
     @property
     def num_samples(self) -> int:
@@ -105,15 +105,17 @@ class PushtDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
 
-        data_dir = self.root / f"{self.dataset_id}"
-        if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
-            self.data_dict = torch.load(data_dir / "data_dict.pth")
-            self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
+        self.data_dir = self.root / f"{self.dataset_id}"
+        if (self.data_dir / "data_dict.pth").exists() and (
+            self.data_dir / "data_ids_per_episode.pth"
+        ).exists():
+            self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+            self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
         else:
             self._download_and_preproc_obsolete()
-            data_dir.mkdir(parents=True, exist_ok=True)
-            torch.save(self.data_dict, data_dir / "data_dict.pth")
-            torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
+            self.data_dir.mkdir(parents=True, exist_ok=True)
+            torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+            torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
 
     @property
     def num_samples(self) -> int:
@@ -46,15 +46,17 @@ class SimxarmDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
 
-        data_dir = self.root / f"{self.dataset_id}"
-        if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
-            self.data_dict = torch.load(data_dir / "data_dict.pth")
-            self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
+        self.data_dir = self.root / f"{self.dataset_id}"
+        if (self.data_dir / "data_dict.pth").exists() and (
+            self.data_dir / "data_ids_per_episode.pth"
+        ).exists():
+            self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+            self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
         else:
             self._download_and_preproc_obsolete()
-            data_dir.mkdir(parents=True, exist_ok=True)
-            torch.save(self.data_dict, data_dir / "data_dict.pth")
-            torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
+            self.data_dir.mkdir(parents=True, exist_ok=True)
+            torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+            torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
 
     @property
     def num_samples(self) -> int:
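
All three dataset classes (Aloha, Pusht, Simxarm) now share the same compute-or-load caching pattern, with `data_dir` promoted to `self.data_dir` so other methods can reuse it. A minimal standalone sketch of that pattern (the `load_or_build` helper and the `build_fn` hook are illustrative, not part of this commit):

```python
from pathlib import Path

import torch


def load_or_build(data_dir: Path, build_fn):
    """Return (data_dict, data_ids_per_episode), from the disk cache if present."""
    dict_path = data_dir / "data_dict.pth"
    ids_path = data_dir / "data_ids_per_episode.pth"
    if dict_path.exists() and ids_path.exists():
        return torch.load(dict_path), torch.load(ids_path)
    # first run: build the structures, then cache both to disk
    data_dict, data_ids_per_episode = build_fn()
    data_dir.mkdir(parents=True, exist_ok=True)
    torch.save(data_dict, dict_path)
    torch.save(data_ids_per_episode, ids_path)
    return data_dict, data_ids_per_episode
```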
@@ -112,16 +112,19 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
 
     if max_num_samples is None:
         max_num_samples = len(dataset)
+    else:
+        raise NotImplementedError("We need to set shuffle=True, but this violates an assert for now.")
 
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
         batch_size=batch_size,
-        shuffle=True,
+        shuffle=False,
         # pin_memory=cfg.device != "cpu",
         drop_last=False,
     )
 
+    # these einops patterns will be used to aggregate batches and compute statistics
     stats_patterns = {
         "action": "b c -> c",
         "observation.state": "b c -> c",
@@ -142,9 +145,9 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
     first_batch = None
     running_item_count = 0  # for online mean computation
     for i, batch in enumerate(
-        tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
     ):
-        this_batch_size = batch.batch_size[0]
+        this_batch_size = len(batch["index"])
         running_item_count += this_batch_size
         if first_batch is None:
             first_batch = deepcopy(batch)
@@ -166,8 +169,10 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
 
     first_batch_ = None
     running_item_count = 0  # for online std computation
-    for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")):
-        this_batch_size = batch.batch_size[0]
+    for i, batch in enumerate(
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
+    ):
+        this_batch_size = len(batch["index"])
         running_item_count += this_batch_size
         # Sanity check to make sure the batches are still in the same order as before.
         if first_batch_ is None:
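
For context, `running_item_count` drives an online mean update that stays exact even when the last batch is smaller than `batch_size` (which is why `drop_last=False` with `shuffle=False` is safe here). A minimal sketch of the update rule, with illustrative shapes:

```python
import torch

# Running (online) mean over batches: after each batch, pull the estimate
# toward the batch mean, weighted by the batch's share of all items seen so far.
mean = torch.zeros(4)
running_item_count = 0
for batch in [torch.randn(32, 4), torch.randn(7, 4)]:  # uneven last batch is fine
    this_batch_size = batch.shape[0]
    running_item_count += this_batch_size
    batch_mean = batch.mean(dim=0)
    mean = mean + (batch_mean - mean) * (this_batch_size / running_item_count)
```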
@@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
         result = {"action": action, "action_pred": action_pred}
         return result
 
-    def compute_loss(self, batch):
-        assert "valid_mask" not in batch
-        nobs = batch["obs"]
-        nactions = batch["action"]
+    def compute_loss(self, obs_dict, action):
+        nobs = obs_dict
+        nactions = action
         batch_size = nactions.shape[0]
         horizon = nactions.shape[1]
 
@@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module):
             "image": batch["observation.image"],
             "agent_pos": batch["observation.state"],
         }
-        loss = self.diffusion.compute_loss(obs_dict)
+        action = batch["action"]
+        loss = self.diffusion.compute_loss(obs_dict, action)
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
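
The two hunks above split the old `batch` argument of `compute_loss` into an observation dict and a separate action tensor. A sketch of the new call contract, with illustrative shapes only (batch=2, horizon=4, action_dim=2, 96x96 images):

```python
import torch

obs_dict = {
    "image": torch.rand(2, 3, 96, 96),  # from batch["observation.image"]
    "agent_pos": torch.rand(2, 2),      # from batch["observation.state"]
}
action = torch.rand(2, 4, 2)            # from batch["action"], (batch, horizon, action_dim)
# loss = self.diffusion.compute_loss(obs_dict, action)  # as called in update()
```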
@@ -72,12 +72,12 @@ class NormalizeTransform(Transform):
            if inkey not in item:
                continue
            if self.mode == "mean_std":
-                mean = self.stats[f"{inkey}.mean"]
-                std = self.stats[f"{inkey}.std"]
+                mean = self.stats[inkey]["mean"]
+                std = self.stats[inkey]["std"]
                item[outkey] = (item[inkey] - mean) / (std + 1e-8)
            else:
-                min = self.stats[f"{inkey}.min"]
-                max = self.stats[f"{inkey}.max"]
+                min = self.stats[inkey]["min"]
+                max = self.stats[inkey]["max"]
                # normalize to [0,1]
                item[outkey] = (item[inkey] - min) / (max - min)
                # normalize to [-1, 1]
@@ -89,12 +89,12 @@ class NormalizeTransform(Transform):
            if inkey not in item:
                continue
            if self.mode == "mean_std":
-                mean = self.stats[f"{inkey}.mean"]
-                std = self.stats[f"{inkey}.std"]
+                mean = self.stats[inkey]["mean"]
+                std = self.stats[inkey]["std"]
                item[outkey] = item[inkey] * std + mean
            else:
-                min = self.stats[f"{inkey}.min"]
-                max = self.stats[f"{inkey}.max"]
+                min = self.stats[inkey]["min"]
+                max = self.stats[inkey]["max"]
                item[outkey] = (item[inkey] + 1) / 2
                item[outkey] = item[outkey] * (max - min) + min
        return item
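
Both directions of `NormalizeTransform` now index a nested stats dict (`stats[key]["mean"]`) instead of flat dotted keys (`stats["{key}.mean"]`). An illustrative round trip under the new layout:

```python
import torch

stats = {
    "observation.state": {
        "mean": torch.zeros(2),
        "std": torch.ones(2),
        "min": torch.zeros(2),
        "max": torch.ones(2),
    },
}
key = "observation.state"
x = torch.rand(2)
normalized = (x - stats[key]["mean"]) / (stats[key]["std"] + 1e-8)
restored = normalized * stats[key]["std"] + stats[key]["mean"]
assert torch.allclose(restored, x, atol=1e-6)
```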
@@ -37,7 +37,6 @@ from pathlib import Path
 
 import einops
 import gymnasium as gym
-import hydra
 import imageio
 import numpy as np
 import torch
@@ -47,8 +46,8 @@ from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 from lerobot.common.transforms import apply_inverse_transform
+from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
 
 
 def write_video(video_path, stacked_frames, fps):
@@ -92,9 +91,12 @@ def eval_policy(
    fps: int = 15,
    return_first_video: bool = False,
    transform: callable = None,
+    seed=None,
 ):
    if policy is not None:
        policy.eval()
+    device = "cpu" if policy is None else next(policy.parameters()).device
+
    start = time.time()
    sum_rewards = []
    max_rewards = []
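
`eval_policy` no longer reads the device from the Hydra config; it infers it from the policy itself, falling back to CPU when rolling out without a policy. A minimal sketch of the idiom:

```python
import torch

policy = torch.nn.Linear(4, 2)  # stand-in for a real policy module
device = "cpu" if policy is None else next(policy.parameters()).device
observation = {"observation.state": torch.rand(1, 4)}
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
```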
@@ -125,11 +127,11 @@ def eval_policy(
        policy.reset()
    else:
        logging.warning(
-            f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
+            f"Policy {policy} doesnt have a `reset` method. It is required if the policy relies on an internal state during rollout."
        )
 
    # reset the environment
-    observation, info = env.reset(seed=cfg.seed)
+    observation, info = env.reset(seed=seed)
    maybe_render_frame(env)
 
    rewards = []
@@ -138,13 +140,12 @@ def eval_policy(
 
    done = torch.tensor([False for _ in env.envs])
    step = 0
-    do_rollout = True
-    while do_rollout:
+    while not done.all():
        # apply transform to normalize the observations
        observation = preprocess_observation(observation, transform)
 
        # send observation to device/gpu
-        observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
+        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
 
        # get the next action for the environment
        with torch.inference_mode():
@@ -180,10 +181,6 @@ def eval_policy(
 
        step += 1
 
-        if done.all():
-            do_rollout = False
-            break
-
    rewards = torch.stack(rewards, dim=1)
    successes = torch.stack(successes, dim=1)
    dones = torch.stack(dones, dim=1)
@@ -295,6 +292,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
        fps=cfg.env.fps,
        # TODO(rcadene): what should we do with the transform?
        transform=dataset.transform,
+        seed=cfg.seed,
    )
    print(info["aggregated"])
 
@@ -145,7 +145,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
    # )
 
    logging.info("make_env")
-    env = make_env(cfg)
+    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 
    logging.info("make_policy")
    policy = make_policy(cfg)
@@ -173,12 +173,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
            eval_info, first_video = eval_policy(
                env,
                policy,
-                num_episodes=cfg.eval_episodes,
-                max_steps=cfg.env.episode_length,
                return_first_video=True,
                video_dir=Path(out_dir) / "eval",
                save_video=True,
                transform=dataset.transform,
+                seed=cfg.seed,
            )
            log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
            if cfg.wandb.enable:
@@ -211,7 +210,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
        for key in batch:
            batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy.update(batch, step)
+        train_info = policy(batch, step)
 
        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
        if step % cfg.log_freq == 0:
@@ -223,6 +222,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
        step += 1
 
+    raise NotImplementedError()
+
    demo_buffer = dataset if cfg.policy.balanced_sampling else None
    online_step = 0
    is_offline = False
@@ -18,28 +18,26 @@ Example:
 import argparse
 import shutil
 
-from tensordict import TensorDict
 from pathlib import Path
+
+import torch
 
 
 def mock_dataset(in_data_dir, out_data_dir, num_frames):
     in_data_dir = Path(in_data_dir)
     out_data_dir = Path(out_data_dir)
 
-    # load full dataset as a tensor dict
-    in_td_data = TensorDict.load_memmap(in_data_dir / "replay_buffer")
+    # copy the first `n` frames for each data key so that we have real data
+    in_data_dict = torch.load(in_data_dir / "data_dict.pth")
+    out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
+    torch.save(out_data_dict, out_data_dir / "data_dict.pth")
 
-    # use 1 frame to know the specification of the dataset
-    # and copy it over `n` frames in the test artifact directory
-    out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_data_dir / "replay_buffer")
+    # copy the full mapping between data_id and episode since it's small
+    in_ids_per_ep_path = in_data_dir / "data_ids_per_episode.pth"
+    out_ids_per_ep_path = out_data_dir / "data_ids_per_episode.pth"
+    shutil.copy(in_ids_per_ep_path, out_ids_per_ep_path)
 
-    # copy the first `n` frames so that we have real data
-    out_td_data[:num_frames] = in_td_data[:num_frames].clone()
-
-    # make sure everything has been properly written
-    out_td_data.lock_()
-
-    # copy the full statistics of dataset since it's pretty small
+    # copy the full statistics of dataset since it's small
     in_stats_path = in_data_dir / "stats.pth"
     out_stats_path = out_data_dir / "stats.pth"
     shutil.copy(in_stats_path, out_stats_path)
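
A quick sanity check one could run on the generated artifact; `out_data_dir` and `num_frames` are illustrative stand-ins for the arguments above:

```python
from pathlib import Path

import torch

out_data_dir = Path("tests/data/mock")  # hypothetical output directory
num_frames = 50
data_dict = torch.load(out_data_dir / "data_dict.pth")
for key, value in data_dict.items():
    assert len(value) == num_frames, f"{key}: {len(value)} frames, expected {num_frames}"
```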
@@ -59,11 +59,7 @@ def test_factory(env_name, dataset_id):
    # )
    # dataset = make_dataset(cfg)
    # # Get all of the data.
-    # all_data = TensorDictReplayBuffer(
-    #     storage=buffer._storage,
-    #     batch_size=len(buffer),
-    #     sampler=SamplerWithoutReplacement(),
-    # ).sample().float()
+    # all_data = dataset.data_dict
    # # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
    # # computation of the statistics. While doing this, we also make sure it works when we don't divide the
    # # dataset into even batches.