[PORT-Hilserl] classifier fixes (#695)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Yoel 2025-02-11 11:39:17 +01:00 committed by AdilZouitine
parent 3c58867738
commit 4057904238
1 changed files with 12 additions and 13 deletions

View File

@ -118,7 +118,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"}) pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8): def validate(model, val_loader, criterion, device, logger, cfg):
# Validation loop with metric tracking and sample logging # Validation loop with metric tracking and sample logging
model.eval() model.eval()
correct = 0 correct = 0
@ -160,15 +160,15 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
running_loss += loss.item() running_loss += loss.item()
# Log sample predictions for visualization # Log sample predictions for visualization
if len(samples) < num_samples_to_log: if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min(num_samples_to_log - len(samples), len(images))): for i in range(min( cfg.eval.num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2: if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3) confidence = round(outputs.probabilities[i].item(), 3)
else: else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()] confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append( samples.append(
{ {
"image": wandb.Image(images[i].cpu()), **{f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) for img_idx, img_key in enumerate(cfg.training.image_keys)},
"true_label": labels[i].item(), "true_label": labels[i].item(),
"predicted": predictions[i].item(), "predicted": predictions[i].item(),
"confidence": confidence, "confidence": confidence,
@ -184,8 +184,8 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
"accuracy": accuracy, "accuracy": accuracy,
"eval_s": time.perf_counter() - batch_start_time, "eval_s": time.perf_counter() - batch_start_time,
"eval/prediction_samples": wandb.Table( "eval/prediction_samples": wandb.Table(
data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples], data=[list(s.values()) for s in samples],
columns=["Image", "True Label", "Predicted", "Confidence"], columns=list(samples[0].keys()),
) )
if logger._cfg.wandb.enable if logger._cfg.wandb.enable
else None, else None,
@ -270,19 +270,18 @@ def train(cfg: DictConfig) -> None:
device = get_safe_torch_device(cfg.device, log=True) device = get_safe_torch_device(cfg.device, log=True)
set_global_seed(cfg.seed) set_global_seed(cfg.seed)
out_dir = Path(cfg.output_dir) out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "classifier"
out_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None) logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
# Setup dataset and dataloaders # Setup dataset and dataloaders
dataset = LeRobotDataset(cfg.dataset_repo_id) dataset = LeRobotDataset(cfg.dataset_repo_id)
logging.info(f"Dataset size: {len(dataset)}") logging.info(f"Dataset size: {len(dataset)}")
train_size = int(cfg.train_split_proportion * len(dataset)) n_total = len(dataset)
# val_size = len(dataset) - train_size n_train = int(cfg.train_split_proportion * len(dataset))
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
train_dataset = dataset[:train_size] val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
val_dataset = dataset[train_size:]
sampler = create_balanced_sampler(train_dataset, cfg) sampler = create_balanced_sampler(train_dataset, cfg)
train_loader = DataLoader( train_loader = DataLoader(