[PORT-Hilserl] classifier fixes (#695)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>

parent 3c58867738
commit 4057904238
@@ -118,7 +118,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
         pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})


-def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
+def validate(model, val_loader, criterion, device, logger, cfg):
     # Validation loop with metric tracking and sample logging
     model.eval()
     correct = 0
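Note: `num_samples_to_log` moves from a function default into the config, so the sample count can be changed per run without editing code. A minimal sketch of the lookup the new signature relies on, assuming a Hydra/OmegaConf config that nests the value under `eval` (the exact layout is an assumption, not taken from the repo):

# Hedged sketch, not the repo's actual config: cfg is assumed to be an
# OmegaConf DictConfig carrying an `eval.num_samples_to_log` entry.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"eval": {"num_samples_to_log": 8}})
print(cfg.eval.num_samples_to_log)  # 8 -- read inside validate() instead of a kwarg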
@@ -160,15 +160,15 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
         running_loss += loss.item()

         # Log sample predictions for visualization
-        if len(samples) < num_samples_to_log:
-            for i in range(min(num_samples_to_log - len(samples), len(images))):
+        if len(samples) < cfg.eval.num_samples_to_log:
+            for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
                 if model.config.num_classes == 2:
                     confidence = round(outputs.probabilities[i].item(), 3)
                 else:
                     confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
                 samples.append(
                     {
-                        "image": wandb.Image(images[i].cpu()),
+                        **{f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) for img_idx, img_key in enumerate(cfg.training.image_keys)},
                         "true_label": labels[i].item(),
                         "predicted": predictions[i].item(),
                         "confidence": confidence,
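Note: the old row logged a single image; the new `**{...}` comprehension expands to one `wandb.Image` per camera in `cfg.training.image_keys`, which also means `images` is now a list of batched tensors, one per key. A hedged sketch of the row construction (the key names and tensor shapes below are illustrative, not from the repo):

# Illustrative only: real values come from cfg.training.image_keys and the dataloader.
import torch
import wandb

image_keys = ["observation.images.top", "observation.images.wrist"]  # assumed keys
images = [torch.rand(4, 3, 96, 96) for _ in image_keys]  # one batch per camera

i = 0  # index of the sample being logged
row = {
    **{f"image_{key}": wandb.Image(images[idx][i].cpu()) for idx, key in enumerate(image_keys)},
    "true_label": 1,
    "predicted": 1,
    "confidence": 0.97,
}
print(list(row.keys()))  # one image_* column per camera, then the scalar fields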
@@ -184,8 +184,8 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
             "accuracy": accuracy,
             "eval_s": time.perf_counter() - batch_start_time,
             "eval/prediction_samples": wandb.Table(
-                data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
-                columns=["Image", "True Label", "Predicted", "Confidence"],
+                data=[list(s.values()) for s in samples],
+                columns=list(samples[0].keys()),
             )
             if logger._cfg.wandb.enable
             else None,
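Note: with a variable number of image columns per row, the hardcoded column list would fall out of sync, so both `data` and `columns` are now derived from the sample dicts. This is safe because every row is built with the same insertion order, which Python dicts preserve. A small self-contained sketch of the same construction:

# Sketch assuming all dicts in `samples` share one key order, as they do in validate().
import wandb

samples = [
    {"image_top": None, "true_label": 1, "predicted": 0, "confidence": 0.42},
    {"image_top": None, "true_label": 0, "predicted": 0, "confidence": 0.91},
]
table = wandb.Table(
    data=[list(s.values()) for s in samples],
    columns=list(samples[0].keys()),
)
print(table.columns)  # column names track whatever keys the rows were built with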
@@ -270,19 +270,18 @@ def train(cfg: DictConfig) -> None:
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)

-    out_dir = Path(cfg.output_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
+    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "classifier"
     logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)

     # Setup dataset and dataloaders
     dataset = LeRobotDataset(cfg.dataset_repo_id)
     logging.info(f"Dataset size: {len(dataset)}")

-    train_size = int(cfg.train_split_proportion * len(dataset))
-    # val_size = len(dataset) - train_size
-    # train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-    train_dataset = dataset[:train_size]
-    val_dataset = dataset[train_size:]
+    n_total = len(dataset)
+    n_train = int(cfg.train_split_proportion * len(dataset))
+    train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
+    val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
+

     sampler = create_balanced_sampler(train_dataset, cfg)
     train_loader = DataLoader(
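Note: plain slicing (`dataset[:train_size]`) is typically not supported by a map-style `Dataset`, and the commented-out `random_split` would scatter frames randomly across the split; `torch.utils.data.Subset` gives a deterministic, contiguous train/val split over the same indices. A runnable sketch with a stand-in dataset:

# Stand-in dataset; Subset works on anything indexable and keeps ordering.
import torch
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.arange(10))
n_total = len(dataset)
n_train = int(0.8 * n_total)  # mirrors cfg.train_split_proportion

train_dataset = Subset(dataset, range(0, n_train))      # first 80% of samples
val_dataset = Subset(dataset, range(n_train, n_total))  # remaining 20%
assert len(train_dataset) + len(val_dataset) == n_total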