[Port HIL-SERL] Balanced sampler function speed up and refactor to align with train.py (#715)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent db78fee9de
commit 83b2dc1219
@@ -3,6 +3,14 @@
 defaults:
   - _self_
 
+hydra:
+  run:
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
+    dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
+  job:
+    name: default
+
 seed: 13
 dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
 # aractingi/push_cube_square_reward_1_cropped_resized
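(The added `hydra` block pins the run directory. Hydra's `${now:...}` resolver formats the launch timestamp, so with, say, an `env.name` of `push_cube`, `dir` resolves to something like `outputs/train_hilserl_classifier/2025-01-20/14-03-07_push_cube_default`. The example values here are illustrative, not from the commit.)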
@@ -42,6 +42,7 @@ from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_hydra_config,
     init_logging,
     set_global_seed,
 )
+from lerobot.scripts.server.buffer import random_shift
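The new import comes from `lerobot/scripts/server/buffer.py`. The diff only confirms the name and location; in the HIL-SERL/DrQ lineage, `random_shift` conventionally denotes a pad-and-random-crop image augmentation. A minimal sketch of that technique, under that assumption:

    import torch
    import torch.nn.functional as F

    def random_shift_sketch(images: torch.Tensor, pad: int = 4) -> torch.Tensor:
        # images: (B, C, H, W). Replicate-pad by `pad` pixels, then crop each
        # image back to (H, W) at a per-sample random offset.
        b, _, h, w = images.shape
        padded = F.pad(images, (pad, pad, pad, pad), mode="replicate")
        out = torch.empty_like(images)
        for i in range(b):
            top = int(torch.randint(0, 2 * pad + 1, (1,)))
            left = int(torch.randint(0, 2 * pad + 1, (1,)))
            out[i] = padded[i, :, top : top + h, left : left + w]
        return out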
@@ -60,9 +61,24 @@ def get_model(cfg, logger): # noqa I001
 
 
 def create_balanced_sampler(dataset, cfg):
     # Creates a weighted sampler to handle class imbalance
-    labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
+    # Get underlying dataset if using Subset
+    original_dataset = (
+        dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
+    )
+    # Get indices if using Subset (for slicing)
+    indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
+
+    # Get labels from Hugging Face dataset
+    if indices is not None:
+        # Get subset of labels using Hugging Face's select()
+        hf_subset = original_dataset.hf_dataset.select(indices)
+        labels = hf_subset[cfg.training.label_key]
+    else:
+        # Get all labels directly
+        labels = original_dataset.hf_dataset[cfg.training.label_key]
+
+    labels = torch.stack(labels)
     _, counts = torch.unique(labels, return_counts=True)
     class_weights = 1.0 / counts.float()
     sample_weights = class_weights[labels]
 
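Reading the label column straight out of `hf_dataset` (with `select()` for `Subset` indices) avoids iterating the torch dataset item by item, which decodes every image just to fetch a label; that is presumably the speedup in the commit title. The hunk is truncated before the function's return, but per-sample `sample_weights` are the standard input to `torch.utils.data.WeightedRandomSampler`. A minimal sketch of that wiring (everything except `sample_weights` is illustrative, not from the diff):

    from torch.utils.data import DataLoader, WeightedRandomSampler

    # One weight per example, inversely proportional to its class frequency,
    # exactly as computed above.
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )
    # Passing a sampler is mutually exclusive with shuffle=True.
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)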
@@ -298,22 +314,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
     return avg, median, std
 
 
-@hydra.main(
-    version_base="1.2",
-    config_path="../configs/policy",
-    config_name="hilserl_classifier",
-)
-def train(cfg: DictConfig) -> None:
+def train(
+    cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
+) -> None:
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
     # Main training pipeline with support for resuming training
     init_logging()
     logging.info(OmegaConf.to_yaml(cfg))
+
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+
     # Initialize training environment
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)
 
-    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
-    logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
-
     # Setup dataset and dataloaders
     dataset = LeRobotDataset(
         cfg.dataset_repo_id,
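This mirrors the `train(cfg, out_dir, job_name)` signature used by `lerobot/scripts/train.py`: the `@hydra.main` decorator moves off `train` itself onto a thin `train_cli` wrapper (added below), `out_dir` and `job_name` become explicit arguments guarded by `NotImplementedError`, and the previously hard-coded `+ "frozen_resnet10_2"` suffix on the run directory is gone.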
@@ -462,5 +480,32 @@ def train(cfg: DictConfig) -> None:
     logging.info("Training completed")
 
 
+@hydra.main(
+    version_base="1.2",
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+)
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(
+    out_dir=None,
+    job_name=None,
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
-    train()
+    train_cli()
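With the two entry points in place, command-line runs dispatch through `train_cli`, which reads `out_dir` and `job_name` from Hydra's runtime config, while `train_notebook` composes a config for interactive sessions. A hypothetical notebook cell (argument values are illustrative; both must be passed, since `train` raises `NotImplementedError` when either is `None`):

    train_notebook(out_dir="outputs/notebook_run", job_name="hilserl_classifier")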