# @package _global_
# NOTE: the @package header must stay the first non-blank line; Hydra stops
# reading header directives at the first line that is not a `# @...` comment.

# Training config for the HIL-SERL reward classifier.

defaults:
  - _self_

seed: 13

# Placeholder (value equals the key name) -- override with a real dataset id,
# e.g. `dataset_repo_id=<user>/<dataset>` on the command line.
dataset_repo_id: "dataset_repo_id"
train_split_proportion: 0.8

# Required by logger
env:
  name: "classifier"
  task: "binary_classification"

training:
  num_epochs: 5
  batch_size: 16
  # Written with a decimal point so plain YAML 1.1 loaders (e.g. PyYAML
  # safe_load) resolve it as a float, not the string "1e-4".
  learning_rate: 1.0e-4
  num_workers: 4
  grad_clip_norm: 10
  use_amp: true
  log_freq: 1
  eval_freq: 1  # How often to run validation (in epochs)
  save_freq: 1  # How often to save checkpoints (in epochs)
  save_checkpoint: true
  # Dataset keys read for the classifier input image and its target label.
  image_key: "observation.images.phone"
  label_key: "next.reward"

eval:
  batch_size: 16
  num_samples_to_log: 30  # Number of validation samples to log in the table

policy:
  name: "hilserl/classifier"
  model_name: "facebook/convnext-base-224"
  model_type: "cnn"

wandb:
  enable: false
  project: "classifier-training"
  # Placeholder -- set to your Weights & Biases entity before enabling.
  entity: "wandb_entity"
  job_name: "classifier_training_0"
  disable_artifact: false

# PyTorch device string. "mps" targets Apple-silicon GPUs; override with
# "cuda" or "cpu" on other machines.
device: "mps"
resume: false
output_dir: "output"