[Port HIL-SERL] Balanced sampler function speed up and refactor to align with train.py (#715)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent db78fee9de
commit 83b2dc1219
@@ -3,6 +3,14 @@
 defaults:
   - _self_
 
+hydra:
+  run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
+    dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
+  job:
+    name: default
+
 seed: 13
 dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
 # aractingi/push_cube_square_reward_1_cropped_resized
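The new dir pattern combines Hydra's now resolver with config interpolation. As a rough preview of how it expands (a sketch only: Hydra registers now itself at runtime, and the ${env.name}/${hydra.job.name} parts are omitted here because they only resolve inside a full Hydra run):

from datetime import datetime

from omegaconf import OmegaConf

# Stand-in for Hydra's built-in `now` resolver, so the interpolation can be
# previewed outside a Hydra run. Hydra's real resolver uses the run start time;
# this one uses the current time.
OmegaConf.register_new_resolver("now", lambda fmt: datetime.now().strftime(fmt))

cfg = OmegaConf.create(
    {"dir": "outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}"}
)
print(cfg.dir)  # e.g. outputs/train_hilserl_classifier/2025-01-30/14-05-12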
@@ -42,6 +42,7 @@ from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_hydra_config,
+    init_logging,
     set_global_seed,
 )
 from lerobot.scripts.server.buffer import random_shift
@@ -60,9 +61,24 @@ def get_model(cfg, logger):  # noqa I001
 
 
 def create_balanced_sampler(dataset, cfg):
-    # Creates a weighted sampler to handle class imbalance
-    labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
+    # Get underlying dataset if using Subset
+    original_dataset = (
+        dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
+    )
+
+    # Get indices if using Subset (for slicing)
+    indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
+
+    # Get labels from Hugging Face dataset
+    if indices is not None:
+        # Get subset of labels using Hugging Face's select()
+        hf_subset = original_dataset.hf_dataset.select(indices)
+        labels = hf_subset[cfg.training.label_key]
+    else:
+        # Get all labels directly
+        labels = original_dataset.hf_dataset[cfg.training.label_key]
+
+    labels = torch.stack(labels)
     _, counts = torch.unique(labels, return_counts=True)
     class_weights = 1.0 / counts.float()
     sample_weights = class_weights[labels]
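The speed-up comes from no longer iterating the dataset item by item (which decodes every sample) and instead reading the label column straight off the underlying hf_dataset, unwrapping torch.utils.data.Subset when the dataset was split. A minimal standalone sketch of the same weighting logic, using a toy label tensor in place of the Hugging Face dataset (an assumption; the real code reads cfg.training.label_key from dataset.hf_dataset):

import torch
from torch.utils.data import Subset, WeightedRandomSampler, random_split

# Toy imbalanced labels; class labels are assumed to be 0..K-1 so they can
# index class_weights directly, mirroring the script's logic.
all_labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])
train_set, _ = random_split(all_labels, [6, 2])  # random_split yields Subsets

# Unwrap the Subset the same way create_balanced_sampler does.
indices = train_set.indices if isinstance(train_set, Subset) else None
labels = all_labels[indices] if indices is not None else all_labels

_, counts = torch.unique(labels, return_counts=True)
class_weights = 1.0 / counts.float()  # rarer class -> larger weight
sample_weights = class_weights[labels]  # one weight per sample
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

Feeding this sampler to a DataLoader draws minority-class samples more often, which is the class-imbalance handling the function's original comment described.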
@@ -298,22 +314,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
     return avg, median, std
 
 
-@hydra.main(
-    version_base="1.2",
-    config_path="../configs/policy",
-    config_name="hilserl_classifier",
-)
-def train(cfg: DictConfig) -> None:
+def train(
+    cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
+) -> None:
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
     # Main training pipeline with support for resuming training
+    init_logging()
     logging.info(OmegaConf.to_yaml(cfg))
 
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+
     # Initialize training environment
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)
 
-    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
-    logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
-
     # Setup dataset and dataloaders
     dataset = LeRobotDataset(
         cfg.dataset_repo_id,
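This mirrors the train.py layout: the @hydra.main decorator moves off train() so the function takes its config, output directory, and job name as plain arguments. One hypothetical way to exercise it directly, e.g. from a test (the output path and job name below are illustrative assumptions):

from hydra import compose, initialize

# Compose the config with Hydra's API instead of the decorator, then call
# train() with explicit out_dir/job_name.
with initialize(config_path="../configs/policy", version_base="1.2"):
    cfg = compose(config_name="hilserl_classifier")
train(cfg, out_dir="outputs/debug_run", job_name="debug")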
@@ -462,5 +480,32 @@ def train(cfg: DictConfig) -> None:
     logging.info("Training completed")
 
 
+@hydra.main(
+    version_base="1.2",
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+)
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(
+    out_dir=None,
+    job_name=None,
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
-    train()
+    train_cli()
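With these two entry points, train_cli keeps the command-line behavior (Hydra supplies the run dir and job name), while train_notebook composes the config manually for interactive sessions. A hypothetical notebook call, where the output directory and job name are assumptions chosen for illustration:

# Run training from a notebook with an explicitly chosen output directory.
train_notebook(
    out_dir="outputs/train_hilserl_classifier/notebook",
    job_name="notebook_run",
)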