[PORT-Hilserl] classifier fixes (#695)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
parent d8f35e9ce9
commit 7b295e159e
@@ -118,7 +118,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
         pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})


-def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
+def validate(model, val_loader, criterion, device, logger, cfg):
     # Validation loop with metric tracking and sample logging
     model.eval()
     correct = 0
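The hard-coded `num_samples_to_log=8` keyword default moves into the config. A minimal sketch of how the new `cfg.eval.num_samples_to_log` field can be read, assuming an OmegaConf/Hydra config like the `DictConfig` passed to `train()`; the exact YAML layout in the repo is an assumption.

from omegaconf import OmegaConf

# Hypothetical config fragment; only the eval.num_samples_to_log key is taken from the diff.
cfg = OmegaConf.create({"eval": {"num_samples_to_log": 8}})
print(cfg.eval.num_samples_to_log)  # 8 — replaces the old hard-coded keyword default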
@@ -160,15 +160,15 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
             running_loss += loss.item()

             # Log sample predictions for visualization
-            if len(samples) < num_samples_to_log:
-                for i in range(min(num_samples_to_log - len(samples), len(images))):
+            if len(samples) < cfg.eval.num_samples_to_log:
+                for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
                     if model.config.num_classes == 2:
                         confidence = round(outputs.probabilities[i].item(), 3)
                     else:
                         confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
                     samples.append(
                         {
-                            "image": wandb.Image(images[i].cpu()),
+                            **{f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) for img_idx, img_key in enumerate(cfg.training.image_keys)},
                             "true_label": labels[i].item(),
                             "predicted": predictions[i].item(),
                             "confidence": confidence,
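The old code logged a single `images[i]`; the new code assumes `images` is a list of per-camera batches ordered like `cfg.training.image_keys` and logs one `wandb.Image` per camera. A minimal sketch of the resulting sample dict, with hypothetical camera keys and random tensors standing in for real frames:

import torch
import wandb

# Hypothetical camera keys; the real ones come from cfg.training.image_keys.
image_keys = ["observation.images.front", "observation.images.wrist"]
# One batched tensor per camera, ordered like image_keys (batch of 4 RGB frames).
images = [torch.rand(4, 3, 96, 96) for _ in image_keys]
i = 0  # index of the sample being logged

sample = {
    **{
        f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
        for img_idx, img_key in enumerate(image_keys)
    },
    "true_label": 1,
    "predicted": 1,
    "confidence": 0.873,
}
# One wandb.Image column per camera, plus label / prediction / confidence.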
@@ -184,8 +184,8 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
         "accuracy": accuracy,
         "eval_s": time.perf_counter() - batch_start_time,
         "eval/prediction_samples": wandb.Table(
-            data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
-            columns=["Image", "True Label", "Predicted", "Confidence"],
+            data=[list(s.values()) for s in samples],
+            columns=list(samples[0].keys()),
         )
         if logger._cfg.wandb.enable
         else None,
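Since the number of image columns now depends on `cfg.training.image_keys`, the table layout can no longer be hard-coded; columns are derived from the sample dicts themselves. A small sketch of the idea, relying on the fact that every sample dict is built from the same literal and therefore has the same key order; the placeholder strings stand in for the `wandb.Image` objects used above.

import wandb

samples = [
    {"image_front": "img-0", "true_label": 0, "predicted": 1, "confidence": 0.61},
    {"image_front": "img-1", "true_label": 1, "predicted": 1, "confidence": 0.93},
]
table = wandb.Table(
    data=[list(s.values()) for s in samples],  # rows follow dict insertion order
    columns=list(samples[0].keys()),           # column names come from the first sample
)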
@@ -270,19 +270,18 @@ def train(cfg: DictConfig) -> None:
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)

-    out_dir = Path(cfg.output_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
+    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "classifier"
     logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)

     # Setup dataset and dataloaders
     dataset = LeRobotDataset(cfg.dataset_repo_id)
     logging.info(f"Dataset size: {len(dataset)}")

-    train_size = int(cfg.train_split_proportion * len(dataset))
-    # val_size = len(dataset) - train_size
-    # train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-    train_dataset = dataset[:train_size]
-    val_dataset = dataset[train_size:]
+    n_total = len(dataset)
+    n_train = int(cfg.train_split_proportion * len(dataset))
+    train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
+    val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
+

     sampler = create_balanced_sampler(train_dataset, cfg)
     train_loader = DataLoader(
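Slicing the dataset (`dataset[:train_size]`) is replaced with contiguous index subsets, presumably because slicing a `LeRobotDataset` does not return a dataset view. A minimal sketch of the same split using `torch.utils.data.Subset`, with a `TensorDataset` standing in for `LeRobotDataset`:

import torch
from torch.utils.data import Subset, TensorDataset

# TensorDataset stands in for LeRobotDataset: 10 frames with dummy binary labels.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10) % 2)
train_split_proportion = 0.8  # stand-in for cfg.train_split_proportion

n_total = len(dataset)
n_train = int(train_split_proportion * n_total)
train_dataset = Subset(dataset, range(0, n_train))      # first 80% of the frames
val_dataset = Subset(dataset, range(n_train, n_total))  # remaining 20%

assert len(train_dataset) == 8 and len(val_dataset) == 2

Subset only remaps indices, so `create_balanced_sampler` and the `DataLoader` downstream still see an ordinary map-style dataset.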