Disable online training (#202)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
parent
c4da689171
commit
2b270d085b
3
Makefile
3
Makefile
|
@ -74,6 +74,7 @@ test-diffusion-ete-eval:
|
|||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
|
||||
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
|
@ -82,7 +83,7 @@ test-tdmpc-ete-train:
|
|||
dataset_repo_id=lerobot/xarm_lift_medium \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
|
|
|
@ -17,6 +17,7 @@ dataset_repo_id: lerobot/pusht
|
|||
|
||||
training:
|
||||
offline_steps: ???
|
||||
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
|
||||
online_steps: ???
|
||||
online_steps_between_rollouts: ???
|
||||
online_sampling_ratio: 0.5
|
||||
|
|
|
@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium
|
|||
|
||||
training:
|
||||
offline_steps: 25000
|
||||
online_steps: 25000
|
||||
# TODO(alexander-soare): uncomment when online training gets reinstated
|
||||
online_steps: 0 # 25000 not implemented yet
|
||||
eval_freq: 5000
|
||||
online_steps_between_rollouts: 1
|
||||
online_sampling_ratio: 0.5
|
||||
|
|
|
@ -18,11 +18,8 @@ import time
|
|||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import hydra
|
||||
import torch
|
||||
from datasets import concatenate_datasets
|
||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
@ -69,7 +66,6 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
|
@ -211,103 +207,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
|||
logger.log_dict(info, step, mode="eval")
|
||||
|
||||
|
||||
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
|
||||
"""
|
||||
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
|
||||
|
||||
Parameters:
|
||||
- n_off (int): Number of offline samples, each with a sampling weight of 1.
|
||||
- n_on (int): Number of online samples.
|
||||
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
|
||||
|
||||
The total weight of offline samples is n_off * 1.0.
|
||||
The total weight of offline samples is n_on * w.
|
||||
The total combined weight of all samples is n_off + n_on * w.
|
||||
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
|
||||
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
|
||||
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
|
||||
"""
|
||||
assert 0.0 <= pc_on <= 1.0
|
||||
return -(n_off * pc_on) / (n_on * (pc_on - 1))
|
||||
|
||||
|
||||
def add_episodes_inplace(
|
||||
online_dataset: torch.utils.data.Dataset,
|
||||
concat_dataset: torch.utils.data.ConcatDataset,
|
||||
sampler: torch.utils.data.WeightedRandomSampler,
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
pc_online_samples: float,
|
||||
):
|
||||
"""
|
||||
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
|
||||
new episodes from hf_dataset into the online_dataset, updating the concatenated
|
||||
dataset's structure and adjusting the sampling strategy based on the specified
|
||||
percentage of online samples.
|
||||
|
||||
Parameters:
|
||||
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
|
||||
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
|
||||
offline and online datasets, used for sampling purposes.
|
||||
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
||||
reflect changes in the dataset sizes and specified sampling weights.
|
||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
||||
They indicate the start index and end index of each episode in the dataset.
|
||||
- pc_online_samples (float): The target percentage of samples that should come from
|
||||
the online dataset during sampling operations.
|
||||
|
||||
Raises:
|
||||
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
||||
"""
|
||||
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
|
||||
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
|
||||
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
||||
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
|
||||
# sanity check
|
||||
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
|
||||
assert first_index == 0, f"{first_index=} is not 0"
|
||||
assert first_index == episode_data_index["from"][first_episode_idx].item()
|
||||
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
|
||||
|
||||
if len(online_dataset) == 0:
|
||||
# initialize online dataset
|
||||
online_dataset.hf_dataset = hf_dataset
|
||||
online_dataset.episode_data_index = episode_data_index
|
||||
else:
|
||||
# get the starting indices of the new episodes and frames to be added
|
||||
start_episode_idx = last_episode_idx + 1
|
||||
start_index = last_index + 1
|
||||
|
||||
def shift_indices(episode_index, index):
|
||||
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
|
||||
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
|
||||
return example
|
||||
|
||||
disable_progress_bars() # map has a tqdm progress bar
|
||||
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
|
||||
enable_progress_bars()
|
||||
|
||||
episode_data_index["from"] += start_index
|
||||
episode_data_index["to"] += start_index
|
||||
|
||||
# extend online dataset
|
||||
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
|
||||
|
||||
# update the concatenated dataset length used during sampling
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
||||
|
||||
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
|
||||
len_online = len(online_dataset)
|
||||
len_offline = len(concat_dataset) - len_online
|
||||
weight_offline = 1.0
|
||||
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
|
||||
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
|
||||
|
||||
# update the total number of samples used during sampling
|
||||
sampler.num_samples = len(concat_dataset)
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
@ -316,8 +215,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
|
||||
init_logging()
|
||||
|
||||
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
|
||||
logging.warning("eval.batch_size > 1 not supported for online training steps")
|
||||
if cfg.training.online_steps > 0:
|
||||
raise NotImplementedError("Online training is not implemented yet.")
|
||||
|
||||
# Check device is available
|
||||
get_safe_torch_device(cfg.device, log=True)
|
||||
|
@ -395,10 +294,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
step = 0 # number of policy update (forward + backward + optim)
|
||||
is_offline = True
|
||||
for offline_step in range(cfg.training.offline_steps):
|
||||
if offline_step == 0:
|
||||
for step in range(cfg.training.offline_steps):
|
||||
if step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
batch = next(dl_iter)
|
||||
|
||||
|
@ -415,11 +313,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
# so we pass in step + 1.
|
||||
evaluate_and_checkpoint_if_needed(step + 1)
|
||||
|
||||
step += 1
|
||||
|
||||
# create an env dedicated to online episodes collection from policy rollout
|
||||
online_training_env = make_env(cfg, n_envs=1)
|
||||
|
||||
# create an empty online dataset similar to offline dataset
|
||||
online_dataset = deepcopy(offline_dataset)
|
||||
online_dataset.hf_dataset = {}
|
||||
|
@ -439,55 +332,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
pin_memory=cfg.device != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
online_step = 0
|
||||
is_offline = False
|
||||
for env_step in range(cfg.training.online_steps):
|
||||
if env_step == 0:
|
||||
logging.info("Start online training by interacting with environment")
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
online_training_env,
|
||||
policy,
|
||||
n_episodes=1,
|
||||
return_episode_data=True,
|
||||
start_seed=cfg.training.online_env_seed,
|
||||
enable_progbar=True,
|
||||
)
|
||||
|
||||
add_episodes_inplace(
|
||||
online_dataset,
|
||||
concat_dataset,
|
||||
sampler,
|
||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
||||
pc_online_samples=cfg.training.online_sampling_ratio,
|
||||
)
|
||||
|
||||
policy.train()
|
||||
for _ in range(cfg.training.online_steps_between_rollouts):
|
||||
batch = next(dl_iter)
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
evaluate_and_checkpoint_if_needed(step + 1)
|
||||
|
||||
step += 1
|
||||
online_step += 1
|
||||
|
||||
eval_env.close()
|
||||
online_training_env.close()
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue