[Port HIL_SERL] Final fixes for the Reward Classifier (#598)
This commit is contained in:
parent
13441f0d98
commit
844bfcf484
|
@ -4,7 +4,6 @@ from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformers import AutoImageProcessor, AutoModel
|
|
||||||
|
|
||||||
from .configuration_classifier import ClassifierConfig
|
from .configuration_classifier import ClassifierConfig
|
||||||
|
|
||||||
|
@ -44,6 +43,8 @@ class Classifier(
|
||||||
name = "classifier"
|
name = "classifier"
|
||||||
|
|
||||||
def __init__(self, config: ClassifierConfig):
|
def __init__(self, config: ClassifierConfig):
|
||||||
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||||
|
|
|
@ -327,7 +327,6 @@ class Critic(nn.Module):
|
||||||
value = self.output_layer(x)
|
value = self.output_layer(x)
|
||||||
return value.squeeze(-1)
|
return value.squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
class Policy(nn.Module):
|
class Policy(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -362,12 +362,16 @@ def sanity_check_dataset_name(repo_id, policy):
|
||||||
|
|
||||||
|
|
||||||
def sanity_check_dataset_robot_compatibility(
|
def sanity_check_dataset_robot_compatibility(
|
||||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||||
|
if extra_features is not None:
|
||||||
|
features_from_robot.update(extra_features)
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||||
("fps", dataset.fps, fps),
|
("fps", dataset.fps, fps),
|
||||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
("features", dataset.features, features_from_robot),
|
||||||
]
|
]
|
||||||
|
|
||||||
mismatches = []
|
mismatches = []
|
||||||
|
|
|
@ -39,7 +39,6 @@ policy:
|
||||||
wandb:
|
wandb:
|
||||||
enable: false
|
enable: false
|
||||||
project: "classifier-training"
|
project: "classifier-training"
|
||||||
entity: "wandb_entity"
|
|
||||||
job_name: "classifier_training_0"
|
job_name: "classifier_training_0"
|
||||||
disable_artifact: false
|
disable_artifact: false
|
||||||
|
|
||||||
|
|
|
@ -246,7 +246,7 @@ def record(
|
||||||
num_processes=num_image_writer_processes,
|
num_processes=num_image_writer_processes,
|
||||||
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||||
)
|
)
|
||||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
|
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features)
|
||||||
else:
|
else:
|
||||||
# Create empty dataset or load existing saved episodes
|
# Create empty dataset or load existing saved episodes
|
||||||
sanity_check_dataset_name(repo_id, policy)
|
sanity_check_dataset_name(repo_id, policy)
|
||||||
|
|
|
@ -183,8 +183,14 @@ def record(
|
||||||
resume: bool = False,
|
resume: bool = False,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
run_compute_stats: bool = True,
|
run_compute_stats: bool = True,
|
||||||
|
assign_rewards: bool = False,
|
||||||
) -> LeRobotDataset:
|
) -> LeRobotDataset:
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
|
|
||||||
|
extra_features = (
|
||||||
|
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||||
|
)
|
||||||
|
|
||||||
policy = None
|
policy = None
|
||||||
if pretrained_policy_name_or_path is not None:
|
if pretrained_policy_name_or_path is not None:
|
||||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||||
|
@ -197,7 +203,7 @@ def record(
|
||||||
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
|
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
|
||||||
|
|
||||||
# initialize listener before sim env
|
# initialize listener before sim env
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||||
|
|
||||||
# create sim env
|
# create sim env
|
||||||
env = env()
|
env = env()
|
||||||
|
@ -237,6 +243,7 @@ def record(
|
||||||
}
|
}
|
||||||
|
|
||||||
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
|
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
|
||||||
|
features = {**features, **extra_features}
|
||||||
|
|
||||||
# Create empty dataset or load existing saved episodes
|
# Create empty dataset or load existing saved episodes
|
||||||
sanity_check_dataset_name(repo_id, policy)
|
sanity_check_dataset_name(repo_id, policy)
|
||||||
|
@ -288,6 +295,13 @@ def record(
|
||||||
"timestamp": env_timestamp,
|
"timestamp": env_timestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Overwrite environment reward with manually assigned reward
|
||||||
|
if assign_rewards:
|
||||||
|
frame["next.reward"] = events["next.reward"]
|
||||||
|
|
||||||
|
# Should success always be false to match what we do in control_utils?
|
||||||
|
frame["next.success"] = False
|
||||||
|
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
if not key.startswith("observation.image"):
|
if not key.startswith("observation.image"):
|
||||||
frame["observation.image." + key] = observation[key]
|
frame["observation.image." + key] = observation[key]
|
||||||
|
@ -472,6 +486,13 @@ if __name__ == "__main__":
|
||||||
default=0,
|
default=0,
|
||||||
help="Resume recording on an existing dataset.",
|
help="Resume recording on an existing dataset.",
|
||||||
)
|
)
|
||||||
|
parser_record.add_argument(
|
||||||
|
"--assign-rewards",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
||||||
|
)
|
||||||
|
|
||||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||||
parser_replay.add_argument(
|
parser_replay.add_argument(
|
||||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||||
|
|
|
@ -45,7 +45,7 @@ from lerobot.common.utils.utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_model(cfg, logger):
|
def get_model(cfg, logger): # noqa I001
|
||||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||||
model = Classifier(classifier_config)
|
model = Classifier(classifier_config)
|
||||||
if cfg.resume:
|
if cfg.resume:
|
||||||
|
@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg):
|
||||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||||
|
|
||||||
|
|
||||||
|
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
|
||||||
|
# Check if the device supports AMP
|
||||||
|
# Here is an example of the issue that says that MPS doesn't support AMP properly
|
||||||
|
return cfg.training.use_amp and device.type in ("cuda", "cpu")
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||||
# Single epoch training loop with AMP support and progress tracking
|
# Single epoch training loop with AMP support and progress tracking
|
||||||
model.train()
|
model.train()
|
||||||
|
@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
|
||||||
labels = batch[cfg.training.label_key].float().to(device)
|
labels = batch[cfg.training.label_key].float().to(device)
|
||||||
|
|
||||||
# Forward pass with optional AMP
|
# Forward pass with optional AMP
|
||||||
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
loss = criterion(outputs.logits, labels)
|
loss = criterion(outputs.logits, labels)
|
||||||
|
|
||||||
|
@ -119,7 +125,10 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
|
||||||
samples = []
|
samples = []
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
with (
|
||||||
|
torch.no_grad(),
|
||||||
|
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
|
||||||
|
):
|
||||||
for batch in tqdm(val_loader, desc="Validation"):
|
for batch in tqdm(val_loader, desc="Validation"):
|
||||||
images = batch[cfg.training.image_key].to(device)
|
images = batch[cfg.training.image_key].to(device)
|
||||||
labels = batch[cfg.training.label_key].float().to(device)
|
labels = batch[cfg.training.label_key].float().to(device)
|
||||||
|
@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
|
||||||
return accuracy, eval_info
|
return accuracy, eval_info
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
|
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
|
||||||
def train(cfg: DictConfig) -> None:
|
def train(cfg: DictConfig) -> None:
|
||||||
# Main training pipeline with support for resuming training
|
# Main training pipeline with support for resuming training
|
||||||
logging.info(OmegaConf.to_yaml(cfg))
|
logging.info(OmegaConf.to_yaml(cfg))
|
||||||
|
|
|
@ -7720,4 +7720,4 @@ xarm = ["gym-xarm"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "b9d299916ced6af1d243f961a32b0a4aacbef18e0b95337a5224e8511f5d6dda"
|
content-hash = "44c74163e398e8ff16973957f69a47bb09b789e92ac4d8fb3ab268defab96427"
|
||||||
|
|
|
@ -71,8 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
|
||||||
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
||||||
pyserial = {version = ">=3.5", optional = true}
|
pyserial = {version = ">=3.5", optional = true}
|
||||||
jsonlines = ">=4.0.0"
|
jsonlines = ">=4.0.0"
|
||||||
transformers = {version = "^4.47.0", optional = true}
|
transformers = {version = ">=4.47.0", optional = true}
|
||||||
torchmetrics = {version = "^1.6.0", optional = true}
|
torchmetrics = {version = ">=1.6.0", optional = true}
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||||
Classifier,
|
|
||||||
ClassifierConfig,
|
ClassifierConfig,
|
||||||
ClassifierOutput,
|
ClassifierOutput,
|
||||||
)
|
)
|
||||||
|
@ -21,6 +20,8 @@ def test_classifier_output():
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
def test_binary_classifier_with_default_params():
|
def test_binary_classifier_with_default_params():
|
||||||
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||||
|
|
||||||
config = ClassifierConfig()
|
config = ClassifierConfig()
|
||||||
classifier = Classifier(config)
|
classifier = Classifier(config)
|
||||||
|
|
||||||
|
@ -40,6 +41,8 @@ def test_binary_classifier_with_default_params():
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
def test_multiclass_classifier():
|
def test_multiclass_classifier():
|
||||||
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||||
|
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
config = ClassifierConfig(num_classes=num_classes)
|
config = ClassifierConfig(num_classes=num_classes)
|
||||||
classifier = Classifier(config)
|
classifier = Classifier(config)
|
||||||
|
@ -60,6 +63,8 @@ def test_multiclass_classifier():
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
def test_default_device():
|
def test_default_device():
|
||||||
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||||
|
|
||||||
config = ClassifierConfig()
|
config = ClassifierConfig()
|
||||||
assert config.device == "cpu"
|
assert config.device == "cpu"
|
||||||
|
|
||||||
|
@ -70,6 +75,8 @@ def test_default_device():
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
def test_explicit_device_setup():
|
def test_explicit_device_setup():
|
||||||
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||||
|
|
||||||
config = ClassifierConfig(device="meta")
|
config = ClassifierConfig(device="meta")
|
||||||
assert config.device == "meta"
|
assert config.device == "meta"
|
||||||
|
|
||||||
|
|
|
@ -151,9 +151,9 @@ def test_validate():
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
|
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
|
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
|
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
|
||||||
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
|
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
|
||||||
def test_resume_function(
|
def test_resume_function(
|
||||||
mock_make_policy,
|
mock_get_model,
|
||||||
mock_dataset,
|
mock_dataset,
|
||||||
mock_logger,
|
mock_logger,
|
||||||
mock_get_last_pretrained_model_dir,
|
mock_get_last_pretrained_model_dir,
|
||||||
|
@ -168,7 +168,7 @@ def test_resume_function(
|
||||||
|
|
||||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
||||||
cfg = compose(
|
cfg = compose(
|
||||||
config_name="reward_classifier",
|
config_name="hilserl_classifier",
|
||||||
overrides=[
|
overrides=[
|
||||||
"device=cpu",
|
"device=cpu",
|
||||||
"seed=42",
|
"seed=42",
|
||||||
|
@ -211,7 +211,7 @@ def test_resume_function(
|
||||||
|
|
||||||
# Instantiate the model and set make_policy to return it
|
# Instantiate the model and set get_model to return it
|
||||||
model = make_dummy_model()
|
model = make_dummy_model()
|
||||||
mock_make_policy.return_value = model
|
mock_get_model.return_value = model
|
||||||
|
|
||||||
# Call train
|
# Call train
|
||||||
train(cfg)
|
train(cfg)
|
||||||
|
|
Loading…
Reference in New Issue