diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py
index 15cf3d0b..0db19cd6 100644
--- a/lerobot/scripts/train_hilserl_classifier.py
+++ b/lerobot/scripts/train_hilserl_classifier.py
@@ -118,7 +118,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
         pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
 
 
-def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
+def validate(model, val_loader, criterion, device, logger, cfg):
     # Validation loop with metric tracking and sample logging
     model.eval()
     correct = 0
@@ -160,15 +160,15 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
             running_loss += loss.item()
 
             # Log sample predictions for visualization
-            if len(samples) < num_samples_to_log:
-                for i in range(min(num_samples_to_log - len(samples), len(images))):
+            if len(samples) < cfg.eval.num_samples_to_log:
+                for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
                     if model.config.num_classes == 2:
                         confidence = round(outputs.probabilities[i].item(), 3)
                     else:
                         confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
                     samples.append(
                         {
-                            "image": wandb.Image(images[i].cpu()),
+                            **{f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) for img_idx, img_key in enumerate(cfg.training.image_keys)},
                             "true_label": labels[i].item(),
                             "predicted": predictions[i].item(),
                             "confidence": confidence,
@@ -184,8 +184,8 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
         "accuracy": accuracy,
         "eval_s": time.perf_counter() - batch_start_time,
         "eval/prediction_samples": wandb.Table(
-            data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
-            columns=["Image", "True Label", "Predicted", "Confidence"],
+            data=[list(s.values()) for s in samples],
+            columns=list(samples[0].keys()),
         )
         if logger._cfg.wandb.enable
         else None,
@@ -270,19 +270,18 @@ def train(cfg: DictConfig) -> None:
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)
 
-    out_dir = Path(cfg.output_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
+    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "/classifier"
 
     logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
 
     # Setup dataset and dataloaders
     dataset = LeRobotDataset(cfg.dataset_repo_id)
     logging.info(f"Dataset size: {len(dataset)}")
 
-    train_size = int(cfg.train_split_proportion * len(dataset))
-    # val_size = len(dataset) - train_size
-    # train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-    train_dataset = dataset[:train_size]
-    val_dataset = dataset[train_size:]
+    n_total = len(dataset)
+    n_train = int(cfg.train_split_proportion * len(dataset))
+    train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
+    val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
+
     sampler = create_balanced_sampler(train_dataset, cfg)
     train_loader = DataLoader(
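Note on `create_balanced_sampler`: it is called on the new train split but not defined in this diff. The sketch below is a minimal, hypothetical implementation using `torch.utils.data.WeightedRandomSampler`; the label key (`cfg.training.label_key`) is an assumed config entry, and the actual helper in `train_hilserl_classifier.py` may differ. It also illustrates why the split had to become a `torch.utils.data.Subset` rather than a slice: `Subset` preserves per-index `__getitem__` access to the underlying `LeRobotDataset`, which the sampler needs.

```python
# Hedged sketch, not the actual lerobot implementation: a class-balanced
# sampler built on torch.utils.data.WeightedRandomSampler. The label key
# (cfg.training.label_key) is a hypothetical config entry.
import torch
from torch.utils.data import WeightedRandomSampler


def create_balanced_sampler(dataset, cfg):
    # Works on a torch.utils.data.Subset, which forwards integer indexing
    # to the underlying LeRobotDataset.
    labels = torch.tensor(
        [int(dataset[i][cfg.training.label_key]) for i in range(len(dataset))]
    )

    # Inverse-frequency weights: each class is drawn with roughly equal
    # probability, which matters for a success/failure reward classifier
    # trained on episodes where one outcome dominates.
    class_counts = torch.bincount(labels)
    weights = 1.0 / class_counts[labels].float()

    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
```

In the real script the labels could likely be read from the dataset's metadata table instead of materializing every sample as above; either way, the sampler only needs the label column of the train split.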