# @package _global_

# Hydra only honours `@package` in the file's leading header comments, so
# nothing (not even a `---` document marker) may be placed above it.
# The defaults list composes no other config groups: only this file's own
# values (`_self_`) are used.
defaults:
  - _self_

# Reproducibility and data selection.
seed: 13
dataset_repo_id: aractingi/push_green_cube_hf_cropped_resized
# Fraction of the dataset used for training; the remainder is presumably
# held out for validation (see `eval`) — confirm against the training script.
train_split_proportion: 0.8

# Required by the logger, which expects an `env` section.
# NOTE(review): these look like placeholder values, since classifier training
# appears to use no interactive environment — confirm.
env:
  name: "classifier"
  task: "binary_classification"

# Optimisation, checkpointing and input-selection settings for the
# classifier training run.
training:
  num_epochs: 5
  batch_size: 16
  # Written with a decimal point on purpose: YAML 1.1 loaders (e.g. plain
  # PyYAML `safe_load`) only treat `d.de-n` as a float and would load `1e-4`
  # as the string "1e-4". `1.0e-4` is a float under every loader.
  learning_rate: 1.0e-4
  num_workers: 4
  grad_clip_norm: 10
  use_amp: true
  log_freq: 1
  eval_freq: 1  # How often to run validation (in epochs)
  save_freq: 1  # How often to save checkpoints (in epochs)
  save_checkpoint: true
  # Camera streams fed to the classifier. The number of entries must match
  # `policy.num_cameras`.
  # Alternative camera set:
  # image_keys: ["observation.images.top", "observation.images.wrist"]
  image_keys: ["observation.images.laptop", "observation.images.phone"]
  # Presumably the dataset field used as the binary classification target —
  # confirm against the dataset schema.
  label_key: "next.reward"

# Validation settings.
eval:
  batch_size: 16
  num_samples_to_log: 30  # Number of validation samples to log in the table

# Classifier model definition.
policy:
  # Alternative name: "hilserl/classifier/pick_place_lego_cube_1"
  name: "hilserl/classifier/push_green_cube_hf_cropped_resized"
  # Presumably a Hugging Face Hub checkpoint id for the backbone — confirm.
  model_name: "facebook/convnext-base-224"
  model_type: "cnn"
  # Has to equal len(training.image_keys); update both together.
  num_cameras: 2

# Weights & Biases logging (disabled by default).
wandb:
  enable: false
  project: "classifier-training"
  job_name: "classifier_training_0"
  disable_artifact: false

# Run-level settings.
# NOTE(review): "mps" presumably selects PyTorch's Apple-silicon (Metal)
# backend; other machines would need "cuda" or "cpu" — confirm.
device: "mps"
resume: false
output_dir: "outputs/classifier"