Merge remote-tracking branch 'upstream/main' into run_resumption

Alexander Soare 2024-05-21 08:48:08 +01:00
commit 54ec151cbb
8 changed files with 162 additions and 196 deletions

View File

@@ -20,6 +20,8 @@ build-gpu:
test-end-to-end:
${MAKE} test-act-ete-train
${MAKE} test-act-ete-eval
+${MAKE} test-act-ete-train-amp
+${MAKE} test-act-ete-eval-amp
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
${MAKE} test-tdmpc-ete-train
@@ -29,6 +31,7 @@ test-end-to-end:
test-act-ete-train:
python lerobot/scripts/train.py \
policy=act \
+policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
@@ -51,9 +54,40 @@ test-act-ete-eval:
env.episode_length=8 \
device=cpu \
test-act-ete-train-amp:
python lerobot/scripts/train.py \
policy=act \
policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
training.save_model=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/act/ \
use_amp=true
test-act-ete-eval-amp:
python lerobot/scripts/eval.py \
-p tests/outputs/act/checkpoints/000002 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=cpu \
use_amp=true
test-diffusion-ete-train:
python lerobot/scripts/train.py \
policy=diffusion \
+policy.down_dims=\[64,128,256\] \
+policy.diffusion_step_embed_dim=32 \
+policy.num_inference_steps=10 \
env=pusht \
wandb.enable=False \
training.offline_steps=2 \
@@ -74,6 +108,7 @@ test-diffusion-ete-eval:
env.episode_length=8 \
device=cpu \
+# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
test-tdmpc-ete-train:
python lerobot/scripts/train.py \
policy=tdmpc \
@@ -82,7 +117,7 @@ test-tdmpc-ete-train:
dataset_repo_id=lerobot/xarm_lift_medium \
wandb.enable=False \
training.offline_steps=2 \
-training.online_steps=2 \
+training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=2 \
@@ -100,7 +135,6 @@ test-tdmpc-ete-eval:
env.episode_length=8 \
device=cpu \
test-default-ete-eval:
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \

View File

@@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
on the target environment, whether that be in simulation or the real world.
"""
+import math
from pathlib import Path
import torch
@@ -39,11 +40,29 @@ delta_timestamps = {
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
-# Load the last 10 episodes of the dataset as a validation set.
-# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
-# For more information on the Slice API, please see:
+# Load the last 10% of episodes of the dataset as a validation set.
+# - Load full dataset
+full_dataset = LeRobotDataset("lerobot/pusht", split="train")
# - Calculate train and val subsets
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
num_val_episodes = full_dataset.num_episodes - num_train_episodes
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
# - Get first frame index of the validation set
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
# - Load frames subset belonging to validation set using the `split` argument.
# It utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
-val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
+train_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
# Create dataloader for evaluation.
val_dataloader = torch.utils.data.DataLoader(
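A rough sketch of how this validation loader might then be consumed, assuming a trained lerobot policy whose forward() returns a dictionary with a "loss" entry (as in lerobot/scripts/train.py) and a torch device named `device`; these names are assumptions for illustration and are not part of the diff above:

policy.eval()
losses = []
with torch.no_grad():
    for batch in val_dataloader:
        # Move every tensor in the batch to the target device before the forward pass.
        batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        losses.append(policy.forward(batch)["loss"].item())
print(f"Mean validation loss: {sum(losses) / len(losses):.4f}")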

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+import re
from pathlib import Path
from typing import Dict
@@ -80,7 +81,23 @@ def hf_transform_to_torch(items_dict):
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
-hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
+hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
# TODO(rcadene): clean this which enables getting a subset of dataset
if split != "train":
if "%" in split:
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
match_from = re.search(r"train\[(\d+):\]", split)
match_to = re.search(r"train\[:(\d+)\]", split)
if match_from:
from_frame_index = int(match_from.group(1))
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
elif match_to:
to_frame_index = int(match_to.group(1))
hf_dataset = hf_dataset.select(range(to_frame_index))
else:
raise ValueError(
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
)
else:
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
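The split strings accepted above can be exercised in isolation. Below is a standalone sketch mirroring the regex logic; `parse_split` is a hypothetical helper written for illustration, not a function from the library. Only "train", "train[INT:]" and "train[:INT]" are recognised, and percentage-based slices are rejected:

import re

def parse_split(split: str, dataset_len: int) -> range:
    # Return the range of frame indices selected by a lerobot-style split string.
    if split == "train":
        return range(dataset_len)
    if "%" in split:
        raise NotImplementedError(f"Percentage-based splits are not supported yet ({split}).")
    match_from = re.search(r"train\[(\d+):\]", split)
    match_to = re.search(r"train\[:(\d+)\]", split)
    if match_from:
        return range(int(match_from.group(1)), dataset_len)
    if match_to:
        return range(int(match_to.group(1)))
    raise ValueError(f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"')

print(list(parse_split("train[3:]", 6)))  # [3, 4, 5]
print(list(parse_split("train[:3]", 6)))  # [0, 1, 2]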
@@ -273,6 +290,12 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
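To make the "from"/"to" bookkeeping concrete, here is a self-contained sketch in plain Python (an illustration, not the library's implementation; `episode_bounds` is a hypothetical name) that reproduces the docstring example of three episodes with 3, 4 and 5 frames:

import torch

def episode_bounds(episode_index: list[int]) -> dict[str, torch.Tensor]:
    # "from" holds the first frame index of each episode; "to" holds one past its last frame.
    froms, tos = [], []
    current = None
    for idx, ep in enumerate(episode_index):
        if ep != current:
            if current is not None:
                tos.append(idx)
            froms.append(idx)
            current = ep
    if episode_index:
        tos.append(len(episode_index))
    return {"from": torch.tensor(froms), "to": torch.tensor(tos)}

print(episode_bounds([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]))
# {'from': tensor([0, 3, 7]), 'to': tensor([ 3,  7, 12])}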
@@ -303,6 +326,8 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
This brings the `episode_index` to the required format.
"""
+if len(hf_dataset) == 0:
+return hf_dataset
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
episode_idx_to_reset_idx_mapping = {
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)

View File

@@ -19,6 +19,9 @@ hydra:
# to change something, you can consider modifying the configuration in the file directly.
resume: false
device: cuda # cpu
+# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+# automatic gradient scaling is used.
+use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
@@ -26,6 +29,7 @@ dataset_repo_id: lerobot/pusht
training:
  offline_steps: ???
+  # NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
  online_steps: ???
  online_steps_between_rollouts: ???
  online_sampling_ratio: 0.5

View File

@@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium
training:
  offline_steps: 25000
-  online_steps: 25000
+  # TODO(alexander-soare): uncomment when online training gets reinstated
+  online_steps: 0 # 25000 not implemented yet
  eval_freq: 5000
  online_steps_between_rollouts: 1
  online_sampling_ratio: 0.5

View File

@@ -46,6 +46,7 @@ import json
import logging
import threading
import time
+from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
@@ -520,7 +521,7 @@ def eval(
raise NotImplementedError()
# Check device is available
-get_safe_torch_device(hydra_cfg.device, log=True)
+device = get_safe_torch_device(hydra_cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -539,16 +540,17 @@ def eval(
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy.eval()
+with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
info = eval_policy(
env,
policy,
hydra_cfg.eval.n_episodes,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
start_seed=hydra_cfg.seed,
enable_progbar=True,
enable_inner_progbar=True,
)
print(info["aggregated"])
# Save info

View File

@@ -15,15 +15,14 @@
# limitations under the License.
import logging
import time
+from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
-import datasets
import hydra
import torch
-from datasets import concatenate_datasets
-from datasets.utils import disable_progress_bars, enable_progress_bars
from omegaconf import DictConfig
+from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -31,6 +30,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
+from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
@@ -70,7 +70,6 @@ def make_optimizer_and_scheduler(cfg, policy):
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
-assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
@@ -88,21 +87,40 @@ def make_optimizer_and_scheduler(cfg, policy):
return optimizer, lr_scheduler
-def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+def update_policy(
policy,
batch,
optimizer,
grad_clip_norm,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
):
"""Returns a dictionary of items for logging.""" """Returns a dictionary of items for logging."""
start_time = time.time() start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train() policy.train()
output_dict = policy.forward(batch) with torch.autocast(device_type=device.type) if use_amp else nullcontext():
# TODO(rcadene): policy.unnormalize_outputs(out_dict) output_dict = policy.forward(batch)
loss = output_dict["loss"] # TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss.backward() loss = output_dict["loss"]
grad_scaler.scale(loss).backward()
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(), policy.parameters(),
grad_clip_norm, grad_clip_norm,
error_if_nonfinite=False, error_if_nonfinite=False,
) )
optimizer.step() # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if lr_scheduler is not None: if lr_scheduler is not None:
@ -116,7 +134,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"loss": loss.item(), "loss": loss.item(),
"grad_norm": float(grad_norm), "grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"], "lr": optimizer.param_groups[0]["lr"],
"update_s": time.time() - start_time, "update_s": time.perf_counter() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"}, **{k: v for k, v in output_dict.items() if k != "loss"},
} }
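For reference, the AMP recipe used by the new `update_policy` can be reproduced end to end on a toy model. The snippet below is only a sketch with made-up model, optimizer, and data, but the `torch.autocast` and `GradScaler` calls are the same PyTorch APIs as above:

import torch
from torch.cuda.amp import GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # autocast and the scaler become no-ops when disabled
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
grad_scaler = GradScaler(enabled=use_amp)

x, y = torch.randn(32, 8, device=device), torch.randn(32, 1, device=device)
with torch.autocast(device_type=device.type, enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)  # unscale before clipping so the norm threshold is in true units
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
grad_scaler.step(optimizer)  # skipped internally if gradients contain inf/NaN
grad_scaler.update()
optimizer.zero_grad()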
@@ -194,103 +212,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
logger.log_dict(info, step, mode="eval")
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
"""
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
Parameters:
- n_off (int): Number of offline samples, each with a sampling weight of 1.
- n_on (int): Number of online samples.
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
The total weight of offline samples is n_off * 1.0.
The total weight of online samples is n_on * w.
The total combined weight of all samples is n_off + n_on * w.
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
"""
assert 0.0 <= pc_on <= 1.0
return -(n_off * pc_on) / (n_on * (pc_on - 1))
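A quick numeric check of the formula in the docstring above (an illustration, not part of the diff): with 1000 offline frames at weight 1.0 and 100 online frames, each online frame gets weight 10.0, so both sides carry a total weight of 1000 and online frames are sampled about 50% of the time:

w = calculate_online_sample_weight(n_off=1000, n_on=100, pc_on=0.5)
assert abs(w - 10.0) < 1e-9
assert abs((100 * w) / (1000 + 100 * w) - 0.5) < 1e-9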
def add_episodes_inplace(
online_dataset: torch.utils.data.Dataset,
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
new episodes from hf_dataset into the online_dataset, updating the concatenated
dataset's structure and adjusting the sampling strategy based on the specified
percentage of online samples.
Parameters:
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes.
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
# sanity check
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
assert first_index == 0, f"{first_index=} is not 0"
assert first_index == episode_data_index["from"][first_episode_idx].item()
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.hf_dataset = hf_dataset
online_dataset.episode_data_index = episode_data_index
else:
# get the starting indices of the new episodes and frames to be added
start_episode_idx = last_episode_idx + 1
start_index = last_index + 1
def shift_indices(episode_index, index):
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
len_online = len(online_dataset)
len_offline = len(concat_dataset) - len_online
weight_offline = 1.0
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
# update the total number of samples used during sampling
sampler.num_samples = len(concat_dataset)
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
@@ -302,13 +223,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
-if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
-logging.warning("eval.batch_size > 1 not supported for online training steps")
+if cfg.training.online_steps > 0:
+raise NotImplementedError("Online training is not implemented yet.")
set_global_seed(cfg.seed)
# Check device is available
-get_safe_torch_device(cfg.device, log=True)
+device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -329,6 +250,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+grad_scaler = GradScaler(enabled=cfg.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
@@ -361,14 +283,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
def evaluate_and_checkpoint_if_needed(step):
if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
+with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
video_dir=Path(out_dir) / "eval",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
@@ -395,22 +318,30 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
num_workers=4,
batch_size=cfg.training.batch_size,
shuffle=True,
-pin_memory=cfg.device != "cpu",
+pin_memory=device.type != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
policy.train()
is_offline = True
-for offline_step in range(step, cfg.training.offline_steps):
-if offline_step == 0:
+for _ in range(step, cfg.training.offline_steps):
+if step == 0:
logging.info("Start offline training on a fixed dataset")
batch = next(dl_iter)
for key in batch:
-batch[key] = batch[key].to(cfg.device, non_blocking=True)
+batch[key] = batch[key].to(device, non_blocking=True)
-train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
+train_info = update_policy(
policy,
batch,
optimizer,
cfg.training.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.training.log_freq == 0:
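`cycle` above is imported from lerobot.common.datasets.utils; a minimal equivalent (an assumption for illustration, not necessarily the library's exact code) just re-iterates the dataloader forever so `next(dl_iter)` never raises StopIteration during a long run, while still reshuffling on each pass when `shuffle=True`:

def cycle(iterable):
    # Endlessly restart iteration over the iterable (e.g. a DataLoader).
    while True:
        yield from iterable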
@@ -422,9 +353,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1
-# create an env dedicated to online episodes collection from policy rollout
-online_training_env = make_env(cfg, n_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
@@ -441,58 +369,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
num_workers=4,
batch_size=cfg.training.batch_size,
sampler=sampler,
-pin_memory=cfg.device != "cpu",
+pin_memory=device.type != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
online_step = 0
is_offline = False
for env_step in range(cfg.training.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
policy.eval()
with torch.no_grad():
eval_info = eval_policy(
online_training_env,
policy,
n_episodes=1,
return_episode_data=True,
start_seed=cfg.training.online_env_seed,
enable_progbar=True,
)
add_episodes_inplace(
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.training.online_sampling_ratio,
)
policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
online_step += 1
eval_env.close()
-online_training_env.close()
logging.info("End of training")

View File

@@ -111,7 +111,7 @@ def test_examples_2_through_4():
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
-('split="train[24342:]"', 'split="train[-1:]"'),
+('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=64", "batch_size=1"),