Extend reward classifier for multiple camera views (#626)

Michel Aractingi 2025-01-13 13:57:49 +01:00 committed by GitHub
parent c5bca1cf0f
commit 3bb5ed5e91
9 changed files with 192 additions and 49 deletions

View File

@@ -25,13 +25,13 @@ from glob import glob
from pathlib import Path

import torch
+import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

-import wandb
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state

View File

@@ -13,6 +13,7 @@ class ClassifierConfig:
    model_name: str = "microsoft/resnet-50"
    device: str = "cpu"
    model_type: str = "cnn"  # "transformer" or "cnn"
+    num_cameras: int = 2

    def save_pretrained(self, save_dir):
        """Save config to json file."""

View File

@@ -97,7 +97,7 @@ class Classifier(
            raise ValueError("Unsupported transformer architecture since hidden_size is not found")

        self.classifier_head = nn.Sequential(
-            nn.Linear(input_dim, self.config.hidden_dim),
+            nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
            nn.Dropout(self.config.dropout_rate),
            nn.LayerNorm(self.config.hidden_dim),
            nn.ReLU(),
@@ -130,11 +130,11 @@ class Classifier(
                return outputs.pooler_output
            return outputs.last_hidden_state[:, 0, :]

-    def forward(self, x: torch.Tensor) -> ClassifierOutput:
+    def forward(self, xs: torch.Tensor) -> ClassifierOutput:
        """Forward pass of the classifier."""
        # For training, we expect input to be a tensor directly from LeRobotDataset
-        encoder_output = self._get_encoder_output(x)
-        logits = self.classifier_head(encoder_output)
+        encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
+        logits = self.classifier_head(encoder_outputs)

        if self.config.num_classes == 2:
            logits = logits.squeeze(-1)
@@ -142,4 +142,10 @@ class Classifier(
        else:
            probabilities = torch.softmax(logits, dim=-1)

-        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output)
+        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
+
+    def predict_reward(self, x):
+        if self.config.num_classes == 2:
+            return (self.forward(x).probabilities > 0.5).float()
+        else:
+            return torch.argmax(self.forward(x).probabilities, dim=1)
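A short usage sketch, assuming a `Classifier` instance named `model` built with `num_cameras=2` and 224x224 RGB inputs as in the tests added by this commit:

```
import torch

# One tensor per camera view, each of shape (batch, channels, height, width).
images = [torch.rand(2, 3, 224, 224), torch.rand(2, 3, 224, 224)]

# forward() encodes each view separately and concatenates the features with torch.hstack,
# which is why the classifier head now takes input_dim * num_cameras features.
output = model(images)                   # ClassifierOutput with logits and probabilities
rewards = model.predict_reward(images)   # 0./1. per sample in the binary (num_classes == 2) case
```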

View File

@@ -11,6 +11,7 @@ from copy import copy
from functools import cache

import cv2
+import numpy as np
import torch
import tqdm
from deepdiff import DeepDiff
@@ -332,6 +333,14 @@ def reset_environment(robot, events, reset_time_s):
            break


+def reset_follower_position(robot: Robot, target_position):
+    current_position = robot.follower_arms["main"].read("Present_Position")
+    trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30))  # NOTE: 30 is just an arbitrary number
+    for pose in trajectory:
+        robot.send_action(pose)
+        busy_wait(0.015)
+
+
def stop_recording(robot, listener, display_cameras):
    robot.disconnect()
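A hedged sketch of how this helper is used elsewhere in the commit, assuming a connected `robot` with a `"main"` follower arm:

```
# Capture the follower pose once, before the episode starts.
initial_position = robot.follower_arms["main"].read("Present_Position")

# ... teleoperate or run the policy for one episode ...

# Interpolate back to the starting pose over 30 steps (~0.45 s at 15 ms per step),
# so the next episode does not begin from an awkward position.
reset_follower_position(robot, target_position=initial_position)
```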

View File

@@ -4,7 +4,7 @@ defaults:
  - _self_

seed: 13
-dataset_repo_id: "dataset_repo_id"
+dataset_repo_id: aractingi/pick_place_lego_cube_1
train_split_proportion: 0.8

# Required by logger
@@ -24,7 +24,7 @@ training:
  eval_freq: 1  # How often to run validation (in epochs)
  save_freq: 1  # How often to save checkpoints (in epochs)
  save_checkpoint: true
-  image_key: "observation.images.phone"
+  image_keys: ["observation.images.top", "observation.images.wrist"]
  label_key: "next.reward"

eval:
@@ -32,9 +32,10 @@ eval:
  num_samples_to_log: 30  # Number of validation samples to log in the table

policy:
-  name: "hilserl/classifier"
+  name: "hilserl/classifier/pick_place_lego_cube_1"
  model_name: "facebook/convnext-base-224"
  model_type: "cnn"
+  num_cameras: 2  # Must match len(training.image_keys)

wandb:
  enable: false
@@ -44,4 +45,4 @@ wandb:

device: "mps"
resume: false
-output_dir: "output"
+output_dir: "outputs/classifier"

View File

@@ -109,6 +109,7 @@ from lerobot.common.robot_devices.control_utils import (
    log_control_info,
    record_episode,
    reset_environment,
+    reset_follower_position,
    sanity_check_dataset_name,
    sanity_check_dataset_robot_compatibility,
    stop_recording,
@@ -205,6 +206,7 @@ def record(
    num_image_writer_threads_per_camera: int = 4,
    display_cameras: bool = True,
    play_sounds: bool = True,
+    reset_follower: bool = False,
    resume: bool = False,
    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
    local_files_only: bool = False,
@@ -265,6 +267,9 @@ def record(
        robot.connect()

    listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
+    if reset_follower:
+        initial_position = robot.follower_arms["main"].read("Present_Position")
+
    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it in starting position if no policy provided,
    # 2. give time to the robot devices to connect and start synchronizing,
@@ -307,6 +312,8 @@ def record(
            (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", play_sounds)
+            if reset_follower:
+                reset_follower_position(robot, initial_position)
            reset_environment(robot, events, reset_time_s)

        if events["rerecord_episode"]:
@@ -527,6 +534,12 @@ if __name__ == "__main__":
        default=0,
        help="Enables the assignment of rewards to frames (disabled by default). When enabled, frames receive a reward of 0 until the space bar is pressed, which switches the reward to 1; pressing the space bar again switches it back to 0. The reward is reset to 0 when the episode ends.",
    )
+    parser_record.add_argument(
+        "--reset-follower",
+        type=int,
+        default=0,
+        help="Resets the follower arm to its initial position while the environment is being reset, so the follower does not start the next episode from an awkward position.",
+    )
    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(

View File

@@ -23,6 +23,15 @@ python lerobot/scripts/eval_on_robot.py \
    eval.n_episodes=10
```
+
+Test the reward classifier with teleoperation (you need to press the space bar to take over):
+```
+python lerobot/scripts/eval_on_robot.py \
+    --robot-path lerobot/configs/robot/so100.yaml \
+    --reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
+    --reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
+    --display-cameras 1
+```

**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
for running training on the real robot.
"""
@@ -30,14 +39,14 @@ for running training on the real robot.
import argparse
import logging
import time
-from copy import deepcopy

+import cv2
import numpy as np
import torch
from tqdm import trange

from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
+from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
    init_hydra_config,
@@ -46,7 +55,33 @@ from lerobot.common.utils.utils import (
)


-def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
+def get_classifier(pretrained_path, config_path):
+    if pretrained_path is None or config_path is None:
+        return
+
+    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
+    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
+    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
+
+    cfg = init_hydra_config(config_path)
+
+    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
+    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
+    model = Classifier(classifier_config)
+    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
+    model = model.to("mps")
+    return model
+
+
+def rollout(
+    robot: Robot,
+    policy: Policy,
+    reward_classifier,
+    fps: int,
+    control_time_s: float = 20,
+    use_amp: bool = True,
+    display_cameras: bool = False,
+) -> dict:
    """Run a batched policy rollout on the real robot.

    The return dictionary contains:
@@ -70,6 +105,7 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
    Returns:
        The dictionary described above.
    """
+    # TODO (michel-aractingi): Infer the device from policy parameters when policy is added
    # assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
    # device = get_device_from_parameters(policy)
@@ -79,25 +115,21 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
    # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
    # policy.reset()

-    # Get observation from real robot
+    # NOTE: sorting to make sure the key sequence is the same during training and testing.
    observation = robot.capture_observation()
+    image_keys = [key for key in observation if "image" in key]
+    image_keys.sort()

-    # Calculate reward. TODO (michel-aractingi)
-    # in HIL-SERL it will be with a reward classifier
-    reward = calculate_reward(observation)

-    all_observations = []
    all_actions = []
    all_rewards = []
    all_successes = []

    start_episode_t = time.perf_counter()
+    init_pos = robot.follower_arms["main"].read("Present_Position")
    timestamp = 0.0
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()
-        all_observations.append(deepcopy(observation))
-        # observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

        # Apply the next action.
        while events["pause_policy"] and not events["human_intervention_step"]:
            busy_wait(0.5)
@@ -109,18 +141,26 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
        else:
            # explore with policy
            with torch.inference_mode():
+                # TODO (michel-aractingi) replace this part with policy (predict_action)
                action = robot.follower_arms["main"].read("Present_Position")
                action = torch.from_numpy(action)
                robot.send_action(action)
                # action = predict_action(observation, policy, device, use_amp)

        observation = robot.capture_observation()
-        # Calculate reward
-        # in HIL-SERL it will be with a reward classifier
-        reward = calculate_reward(observation)
+        images = []
+        for key in image_keys:
+            if display_cameras:
+                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
+                cv2.waitKey(1)
+            images.append(observation[key].to("mps"))
+
+        reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
+        all_rewards.append(reward)
+        # print("REWARD : ", reward)

        all_actions.append(action)
-        all_rewards.append(torch.from_numpy(reward))
        all_successes.append(torch.tensor([False]))

        dt_s = time.perf_counter() - start_loop_t
@@ -131,7 +171,8 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
                events["human_intervention_step"] = False
                events["pause_policy"] = False
                break
-        all_observations.append(deepcopy(observation))
+
+    reset_follower_position(robot, target_position=init_pos)

    dones = torch.tensor([False] * len(all_actions))
    dones[-1] = True
@@ -142,10 +183,6 @@ def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20,
        "next.success": torch.stack(all_successes, dim=1),
        "done": dones,
    }
-    stacked_observations = {}
-    for key in all_observations[0]:
-        stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
-    ret["observation"] = stacked_observations

    listener.stop()
@@ -159,6 +196,9 @@ def eval_policy(
    n_episodes: int,
    control_time_s: int = 20,
    use_amp: bool = True,
+    display_cameras: bool = False,
+    reward_classifier_pretrained_path: str | None = None,
+    reward_classifier_config_file: str | None = None,
) -> dict:
    """
    Args:
@@ -179,8 +219,12 @@ def eval_policy(
    start_eval = time.perf_counter()
    progbar = trange(n_episodes, desc="Evaluating policy on real robot")
-    for _batch_idx in progbar:
-        rollout_data = rollout(robot, policy, fps, control_time_s, use_amp)
+    reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
+
+    for _ in progbar:
+        rollout_data = rollout(
+            robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
+        )
        rollouts.append(rollout_data)
        sum_rewards.append(sum(rollout_data["next.reward"]))
@@ -219,15 +263,6 @@ def eval_policy(
    return info


-def calculate_reward(observation):
-    """
-    Method to calculate reward function in some way.
-    In HIL-SERL this is done through defining a reward classifier
-    """
-    # reward = reward_classifier(observation)
-    return np.array([0.0])
-
-
def init_keyboard_listener():
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
@@ -324,6 +359,21 @@ if __name__ == "__main__":
            "outputs/eval/{timestamp}_{env_name}_{policy_name}"
        ),
    )
+    parser.add_argument(
+        "--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
+    )
+    parser.add_argument(
+        "--reward-classifier-pretrained-path",
+        type=str,
+        default=None,
+        help="Path to the pretrained classifier weights.",
+    )
+    parser.add_argument(
+        "--reward-classifier-config-file",
+        type=str,
+        default=None,
+        help="Path to a yaml config file that is necessary to build the reward classifier model.",
+    )

    args = parser.parse_args()
@@ -332,4 +382,13 @@ if __name__ == "__main__":
    if not robot.is_connected:
        robot.connect()

-    eval_policy(robot, None, fps=40, n_episodes=2, control_time_s=100)
+    eval_policy(
+        robot,
+        None,
+        fps=40,
+        n_episodes=2,
+        control_time_s=100,
+        display_cameras=args.display_cameras,
+        reward_classifier_config_file=args.reward_classifier_config_file,
+        reward_classifier_pretrained_path=args.reward_classifier_pretrained_path,
+    )
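A condensed sketch of how the pieces added to `eval_on_robot.py` fit together, assuming a connected `robot` and the checkpoint/config paths from the example above (the `"mps"` device mirrors the hard-coded choice in this commit):

```
reward_classifier = get_classifier(
    "outputs/classifier/checkpoints/best/pretrained_model",
    "lerobot/configs/policy/hilserl_classifier.yaml",
)

observation = robot.capture_observation()
# Sort so the camera order matches the order used for training.image_keys at training time.
image_keys = sorted(key for key in observation if "image" in key)
images = [observation[key].to("mps") for key in image_keys]

reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
```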

View File

@@ -22,6 +22,7 @@ from pprint import pformat
import hydra
import torch
import torch.nn as nn
+import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm

-import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
@@ -79,7 +79,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        start_time = time.perf_counter()
-        images = batch[cfg.training.image_key].to(device)
+        images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
        labels = batch[cfg.training.label_key].float().to(device)

        # Forward pass with optional AMP
@@ -130,7 +130,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
        torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
    ):
        for batch in tqdm(val_loader, desc="Validation"):
-            images = batch[cfg.training.image_key].to(device)
+            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
            labels = batch[cfg.training.label_key].float().to(device)

            outputs = model(images)
@@ -163,6 +163,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
    accuracy = 100 * correct / total
    avg_loss = running_loss / len(val_loader)

+    print(f"Average validation loss {avg_loss}, and accuracy {accuracy}")
    eval_info = {
        "loss": avg_loss,

View File

@@ -33,7 +33,9 @@ class MockDataset(Dataset):


def make_dummy_model():
-    model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel")
+    model_config = ClassifierConfig(
+        num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
+    )
    model = Classifier(config=model_config)
    return model
@@ -88,7 +90,7 @@ def test_train_epoch():
    logger = MagicMock()
    step = 0
    cfg = MagicMock()
-    cfg.training.image_key = "image"
+    cfg.training.image_keys = ["image"]
    cfg.training.label_key = "label"
    cfg.training.use_amp = False
@@ -130,7 +132,7 @@ def test_validate():
    device = torch.device("cpu")
    logger = MagicMock()
    cfg = MagicMock()
-    cfg.training.image_key = "image"
+    cfg.training.image_keys = ["image"]
    cfg.training.label_key = "label"
    cfg.training.use_amp = False
@@ -145,6 +147,57 @@ def test_validate():
    assert isinstance(eval_info, dict)


+def test_train_epoch_multiple_cameras():
+    model_config = ClassifierConfig(
+        num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
+    )
+    model = Classifier(config=model_config)
+
+    # Mock components
+    model.train = MagicMock()
+
+    train_loader = [
+        {
+            "image_1": torch.rand(2, 3, 224, 224),
+            "image_2": torch.rand(2, 3, 224, 224),
+            "label": torch.tensor([0.0, 1.0]),
+        }
+    ]
+
+    criterion = nn.BCEWithLogitsLoss()
+    optimizer = MagicMock()
+    grad_scaler = MagicMock()
+    device = torch.device("cpu")
+    logger = MagicMock()
+    step = 0
+    cfg = MagicMock()
+    cfg.training.image_keys = ["image_1", "image_2"]
+    cfg.training.label_key = "label"
+    cfg.training.use_amp = False
+
+    # Call the function under test
+    train_epoch(
+        model,
+        train_loader,
+        criterion,
+        optimizer,
+        grad_scaler,
+        device,
+        logger,
+        step,
+        cfg,
+    )
+
+    # Check that model.train() was called
+    model.train.assert_called_once()
+
+    # Check that optimizer.zero_grad() was called
+    optimizer.zero_grad.assert_called()
+
+    # Check that logger.log_dict was called
+    logger.log_dict.assert_called()
+
+
@pytest.mark.parametrize("resume", [True, False])
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
@@ -179,7 +232,7 @@ def test_resume_function(
        "train_split_proportion=0.8",
        "training.num_workers=0",
        "training.batch_size=2",
-        "training.image_key=image",
+        "training.image_keys=[image]",
        "training.label_key=label",
        "training.use_amp=False",
        "training.num_epochs=1",