[Port HIL-SERL] Balanced sampler function speed up and refactor to align with train.py (#715)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by s1lent4gnt 2025-03-12 10:35:30 +01:00, committed by AdilZouitine
parent db78fee9de
commit 83b2dc1219
2 changed files with 65 additions and 12 deletions


@@ -3,6 +3,14 @@
 defaults:
   - _self_

+hydra:
+  run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
+    dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
+  job:
+    name: default
+
 seed: 13
 dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
 # aractingi/push_cube_square_reward_1_cropped_resized

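Note: Hydra resolves the ${now:...} interpolations at launch time, so with these defaults a run started at this commit's timestamp would write its outputs to a directory like the following (the env.name segment is a placeholder; it comes from the env config):

    outputs/train_hilserl_classifier/2025-03-12/10-35-30_<env.name>_default

The location can still be changed per run with a standard Hydra command-line override such as hydra.run.dir=outputs/my_run.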

@@ -42,6 +42,7 @@ from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_hydra_config,
+    init_logging,
     set_global_seed,
 )
 from lerobot.scripts.server.buffer import random_shift
@@ -60,9 +61,24 @@ def get_model(cfg, logger): # noqa I001
 def create_balanced_sampler(dataset, cfg):
     # Creates a weighted sampler to handle class imbalance
-    labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
+    # Get underlying dataset if using Subset
+    original_dataset = (
+        dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
+    )
+
+    # Get indices if using Subset (for slicing)
+    indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
+
+    # Get labels from Hugging Face dataset
+    if indices is not None:
+        # Get subset of labels using Hugging Face's select()
+        hf_subset = original_dataset.hf_dataset.select(indices)
+        labels = hf_subset[cfg.training.label_key]
+    else:
+        # Get all labels directly
+        labels = original_dataset.hf_dataset[cfg.training.label_key]
+
+    labels = torch.stack(labels)
     _, counts = torch.unique(labels, return_counts=True)
     class_weights = 1.0 / counts.float()
     sample_weights = class_weights[labels]
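The speed-up comes from reading the label column straight out of the underlying Hugging Face dataset instead of calling dataset[i] for every sample, which would also decode images. Below is a minimal self-contained sketch of the sampling scheme this function feeds into; the toy labels and the WeightedRandomSampler construction at the end are illustrative, since the hunk above only shows the weight computation.

    import torch
    from torch.utils.data import WeightedRandomSampler

    labels = torch.tensor([0, 0, 0, 1])  # toy imbalanced labels (3:1)
    _, counts = torch.unique(labels, return_counts=True)  # counts = [3, 1]
    class_weights = 1.0 / counts.float()  # rarer class gets a larger weight
    sample_weights = class_weights[labels]  # per-sample weight lookup
    # Draw indices with probability proportional to their weight, so each
    # class is seen roughly equally often per epoch.
    sampler = WeightedRandomSampler(
        sample_weights, num_samples=len(sample_weights), replacement=True
    )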
@@ -298,22 +314,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
     return avg, median, std


-@hydra.main(
-    version_base="1.2",
-    config_path="../configs/policy",
-    config_name="hilserl_classifier",
-)
-def train(cfg: DictConfig) -> None:
+def train(
+    cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
+) -> None:
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
     # Main training pipeline with support for resuming training
     init_logging()
     logging.info(OmegaConf.to_yaml(cfg))
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)

     # Initialize training environment
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)
-    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
-    logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)

     # Setup dataset and dataloaders
     dataset = LeRobotDataset(
         cfg.dataset_repo_id,
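With the decorator gone, train() is now a plain function that takes its output directory and job name from the caller instead of reading Hydra state (and the leftover hardcoded "frozen_resnet10_2" suffix is dropped). A minimal sketch of invoking it directly, using the init_hydra_config helper already imported above; the config path and output directory here are assumptions for illustration:

    from lerobot.common.utils.utils import init_hydra_config

    cfg = init_hydra_config("lerobot/configs/policy/hilserl_classifier.yaml")  # assumed path
    train(cfg, out_dir="outputs/debug_run", job_name="debug")  # both are required by the guard clauses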
@@ -462,5 +480,32 @@ def train(cfg: DictConfig) -> None:
     logging.info("Training completed")


+@hydra.main(
+    version_base="1.2",
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+)
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(
+    out_dir=None,
+    job_name=None,
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
-    train()
+    train_cli()
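This mirrors the entry-point split in train.py: train_cli keeps the @hydra.main decorator for command-line use, while train_notebook composes the config programmatically so the same pipeline can run inside a notebook, where Hydra's decorator cannot manage the process. A sketch of both entry points in use; the script's path and module name are assumptions:

    # CLI, with standard Hydra overrides:
    # python lerobot/scripts/train_hilserl_classifier.py seed=42

    # Notebook / REPL:
    from lerobot.scripts.train_hilserl_classifier import train_notebook
    train_notebook(out_dir="outputs/notebook_run", job_name="notebook")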