import os import tempfile from pathlib import Path from unittest.mock import MagicMock, patch import pytest import torch from hydra import compose, initialize_config_dir from torch import nn from torch.utils.data import Dataset from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier from lerobot.scripts.train_hilserl_classifier import ( create_balanced_sampler, train, train_epoch, validate, ) class MockDataset(Dataset): def __init__(self, data): self.data = data self.meta = MagicMock() self.meta.stats = {} def __getitem__(self, idx): return self.data[idx] def __len__(self): return len(self.data) def make_dummy_model(): model_config = ClassifierConfig( num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1 ) model = Classifier(config=model_config) return model def test_create_balanced_sampler(): # Mock dataset with imbalanced classes data = [ {"label": 0}, {"label": 0}, {"label": 1}, {"label": 0}, {"label": 1}, {"label": 1}, {"label": 1}, {"label": 1}, ] dataset = MockDataset(data) cfg = MagicMock() cfg.training.label_key = "label" sampler = create_balanced_sampler(dataset, cfg) # Get weights from the sampler weights = sampler.weights.float() # Check that samples have appropriate weights labels = [item["label"] for item in data] class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32) class_weights = 1.0 / class_counts expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32) # Test that the weights are correct assert torch.allclose(weights, expected_weights) def test_train_epoch(): model = make_dummy_model() # Mock components model.train = MagicMock() train_loader = [ { "image": torch.rand(2, 3, 224, 224), "label": torch.tensor([0.0, 1.0]), } ] criterion = nn.BCEWithLogitsLoss() optimizer = MagicMock() grad_scaler = MagicMock() device = torch.device("cpu") logger = MagicMock() step = 0 cfg = MagicMock() cfg.training.image_keys = ["image"] cfg.training.label_key = "label" cfg.training.use_amp = False # Call the function under test train_epoch( model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg, ) # Check that model.train() was called model.train.assert_called_once() # Check that optimizer.zero_grad() was called optimizer.zero_grad.assert_called() # Check that logger.log_dict was called logger.log_dict.assert_called() def test_validate(): model = make_dummy_model() # Mock components model.eval = MagicMock() val_loader = [ { "image": torch.rand(2, 3, 224, 224), "label": torch.tensor([0.0, 1.0]), } ] criterion = nn.BCEWithLogitsLoss() device = torch.device("cpu") logger = MagicMock() cfg = MagicMock() cfg.training.image_keys = ["image"] cfg.training.label_key = "label" cfg.training.use_amp = False # Call validate accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg) # Check that model.eval() was called model.eval.assert_called_once() # Check accuracy/eval_info are calculated and of the correct type assert isinstance(accuracy, float) assert isinstance(eval_info, dict) def test_train_epoch_multiple_cameras(): model_config = ClassifierConfig( num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2 ) model = Classifier(config=model_config) # Mock components model.train = MagicMock() train_loader = [ { "image_1": torch.rand(2, 3, 224, 224), "image_2": torch.rand(2, 3, 224, 224), "label": torch.tensor([0.0, 1.0]), } ] criterion = nn.BCEWithLogitsLoss() optimizer = MagicMock() grad_scaler = MagicMock() device = torch.device("cpu") logger = MagicMock() step = 0 cfg = MagicMock() cfg.training.image_keys = ["image_1", "image_2"] cfg.training.label_key = "label" cfg.training.use_amp = False # Call the function under test train_epoch( model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg, ) # Check that model.train() was called model.train.assert_called_once() # Check that optimizer.zero_grad() was called optimizer.zero_grad.assert_called() # Check that logger.log_dict was called logger.log_dict.assert_called() @pytest.mark.parametrize("resume", [True, False]) @patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config") @patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir") @patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir") @patch("lerobot.scripts.train_hilserl_classifier.Logger") @patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset") @patch("lerobot.scripts.train_hilserl_classifier.get_model") def test_resume_function( mock_get_model, mock_dataset, mock_logger, mock_get_last_pretrained_model_dir, mock_get_last_checkpoint_dir, mock_init_hydra_config, resume, ): # Initialize Hydra test_file_dir = os.path.dirname(os.path.abspath(__file__)) config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")) assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}" with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"): cfg = compose( config_name="hilserl_classifier", overrides=[ "device=cpu", "seed=42", f"output_dir={tempfile.mkdtemp()}", "wandb.enable=False", f"resume={resume}", "dataset_repo_id=dataset_repo_id", "train_split_proportion=0.8", "training.num_workers=0", "training.batch_size=2", "training.image_keys=[image]", "training.label_key=label", "training.use_amp=False", "training.num_epochs=1", "eval.batch_size=2", ], ) # Mock the init_hydra_config function to return cfg mock_init_hydra_config.return_value = cfg # Mock dataset dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]) mock_dataset.return_value = dataset # Mock checkpoint handling mock_checkpoint_dir = MagicMock(spec=Path) mock_checkpoint_dir.exists.return_value = resume # Only exists if resuming mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp()) # Mock logger logger = MagicMock() resumed_step = 1000 if resume: logger.load_last_training_state.return_value = resumed_step else: logger.load_last_training_state.return_value = 0 mock_logger.return_value = logger # Instantiate the model and set make_policy to return it model = make_dummy_model() mock_get_model.return_value = model # Call train train(cfg) # Check that checkpoint handling methods were called if resume: mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir)) mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir)) mock_checkpoint_dir.exists.assert_called_once() logger.load_last_training_state.assert_called_once() else: mock_get_last_checkpoint_dir.assert_not_called() mock_get_last_pretrained_model_dir.assert_not_called() mock_checkpoint_dir.exists.assert_not_called() logger.load_last_training_state.assert_not_called() # Collect the steps from logger.log_dict calls train_log_calls = logger.log_dict.call_args_list # Extract the steps used in the train logging steps = [] for call in train_log_calls: mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None) if mode == "train": step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None) steps.append(step) expected_start_step = resumed_step if resume else 0 # Calculate expected_steps train_size = int(cfg.train_split_proportion * len(dataset)) batch_size = cfg.training.batch_size num_batches = (train_size + batch_size - 1) // batch_size expected_steps = [expected_start_step + i for i in range(num_batches)] assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"