diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml
index 149eeab2..9ab181d5 100644
--- a/lerobot/configs/policy/hilserl_classifier.yaml
+++ b/lerobot/configs/policy/hilserl_classifier.yaml
@@ -3,6 +3,14 @@ defaults:
   - _self_
 
+hydra:
+  run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
+    dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
+  job:
+    name: default
+
 seed: 13
 dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized # aractingi/push_cube_square_reward_1_cropped_resized
diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py
index 6044b038..1cae0183 100644
--- a/lerobot/scripts/train_hilserl_classifier.py
+++ b/lerobot/scripts/train_hilserl_classifier.py
@@ -42,6 +42,7 @@ from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_hydra_config,
+    init_logging,
     set_global_seed,
 )
 from lerobot.scripts.server.buffer import random_shift
@@ -60,9 +61,24 @@ def get_model(cfg, logger):  # noqa I001
 
 
 def create_balanced_sampler(dataset, cfg):
-    # Creates a weighted sampler to handle class imbalance
+    # Get underlying dataset if using Subset
+    original_dataset = (
+        dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
+    )
 
-    labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
+    # Get indices if using Subset (for slicing)
+    indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
+
+    # Get labels from Hugging Face dataset
+    if indices is not None:
+        # Get subset of labels using Hugging Face's select()
+        hf_subset = original_dataset.hf_dataset.select(indices)
+        labels = hf_subset[cfg.training.label_key]
+    else:
+        # Get all labels directly
+        labels = original_dataset.hf_dataset[cfg.training.label_key]
+
+    labels = torch.stack(labels)
     _, counts = torch.unique(labels, return_counts=True)
     class_weights = 1.0 / counts.float()
     sample_weights = class_weights[labels]
@@ -298,22 +314,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
     return avg, median, std
 
 
-@hydra.main(
-    version_base="1.2",
-    config_path="../configs/policy",
-    config_name="hilserl_classifier",
-)
-def train(cfg: DictConfig) -> None:
+def train(
+    cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
+) -> None:
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
     # Main training pipeline with support for resuming training
+    init_logging()
     logging.info(OmegaConf.to_yaml(cfg))
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
 
     # Initialize training environment
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)
 
-    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
-    logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
-
     # Setup dataset and dataloaders
     dataset = LeRobotDataset(
         cfg.dataset_repo_id,
@@ -462,5 +480,32 @@ def train(cfg: DictConfig) -> None:
     logging.info("Training completed")
 
 
+@hydra.main(
+    version_base="1.2",
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+)
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(
+    out_dir=None,
+    job_name=None,
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
-    train()
+    train_cli()
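Note on the create_balanced_sampler() change: reading labels via original_dataset.hf_dataset (with select() for Subset indices) avoids going through the dataset's __getitem__, which would decode every camera frame just to fetch a label. The function's return statement falls outside this hunk; judging from the weights computed above, it presumably feeds a torch.utils.data.WeightedRandomSampler. A minimal, self-contained sketch of that weighting scheme, with toy labels standing in for the cfg.training.label_key values:

    import torch
    from torch.utils.data import WeightedRandomSampler

    labels = torch.tensor([0, 0, 0, 0, 1])                # class 1 is under-represented
    _, counts = torch.unique(labels, return_counts=True)  # counts == tensor([4, 1])
    class_weights = 1.0 / counts.float()                  # rarer class -> larger weight
    sample_weights = class_weights[labels]                # one weight per sample
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )  # draws class-balanced batches in expectation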
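The entry-point split (train_cli() keeps the @hydra.main decorator, train_notebook() goes through Hydra's Compose API) is what makes the script usable from Jupyter, where a decorated main cannot take over the process; GlobalHydra.instance().clear() also lets a cell be re-run without an "already initialized" error. A sketch of the intended notebook call, assuming the repo root is importable; the output path and job name below are arbitrary examples, not values from this patch:

    from lerobot.scripts.train_hilserl_classifier import train_notebook

    # train() raises NotImplementedError when out_dir or job_name is None,
    # so both must be passed explicitly once Hydra's run dir is bypassed.
    train_notebook(
        out_dir="outputs/train_hilserl_classifier/notebook_run",  # example path
        job_name="notebook_run",  # example name
    )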