Handle new config with sac
This commit is contained in:
parent
b2025b852c
commit
f483931fc0
|
@ -55,10 +55,9 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||||
|
|
||||||
return PI0Policy
|
return PI0Policy
|
||||||
elif name == "sac":
|
elif name == "sac":
|
||||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
|
||||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||||
|
|
||||||
return SACPolicy, SACConfig
|
return SACPolicy
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||||
|
|
||||||
|
|
|
@ -79,28 +79,38 @@ def create_stats_buffers(
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||||
if stats:
|
if stats and key in stats:
|
||||||
if isinstance(stats[key]["mean"], np.ndarray):
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||||
|
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
|
||||||
|
|
||||||
|
if isinstance(stats[key]["mean"], np.ndarray):
|
||||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
# unnormalization). See the logic here
|
||||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
|
||||||
# unnormalization). See the logic here
|
|
||||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
|
||||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
else:
|
||||||
|
type_ = type(stats[key]["mean"])
|
||||||
|
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
|
||||||
|
|
||||||
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
|
if "min" not in stats[key] or "max" not in stats[key]:
|
||||||
|
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
|
||||||
|
|
||||||
|
if isinstance(stats[key]["min"], np.ndarray):
|
||||||
|
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||||
|
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||||
|
elif isinstance(stats[key]["min"], torch.Tensor):
|
||||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
type_ = type(stats[key]["mean"])
|
type_ = type(stats[key]["min"])
|
||||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
|
||||||
|
|
||||||
stats_buffers[key] = buffer
|
stats_buffers[key] = buffer
|
||||||
return stats_buffers
|
return stats_buffers
|
||||||
|
@ -155,6 +165,7 @@ class Normalize(nn.Module):
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||||
|
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
|
||||||
continue
|
continue
|
||||||
|
|
||||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||||
|
|
|
@ -18,57 +18,82 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("sac")
|
||||||
@dataclass
|
@dataclass
|
||||||
class SACConfig:
|
class SACConfig(PreTrainedConfig):
|
||||||
input_shapes: dict[str, list[int]] = field(
|
"""Configuration class for Soft Actor-Critic (SAC) policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
|
||||||
|
normalization_mapping: Mapping from feature types to normalization modes.
|
||||||
|
camera_number: Number of cameras to use.
|
||||||
|
storage_device: Device to use for storage.
|
||||||
|
vision_encoder_name: Name of the vision encoder to use.
|
||||||
|
freeze_vision_encoder: Whether to freeze the vision encoder.
|
||||||
|
image_encoder_hidden_dim: Hidden dimension for the image encoder.
|
||||||
|
shared_encoder: Whether to use a shared encoder.
|
||||||
|
discount: Discount factor for the RL algorithm.
|
||||||
|
temperature_init: Initial temperature for entropy regularization.
|
||||||
|
num_critics: Number of critic networks.
|
||||||
|
num_subsample_critics: Number of critics to subsample.
|
||||||
|
critic_lr: Learning rate for critic networks.
|
||||||
|
actor_lr: Learning rate for actor network.
|
||||||
|
temperature_lr: Learning rate for temperature parameter.
|
||||||
|
critic_target_update_weight: Weight for soft target updates.
|
||||||
|
utd_ratio: Update-to-data ratio (>1 to enable).
|
||||||
|
state_encoder_hidden_dim: Hidden dimension for state encoder.
|
||||||
|
latent_dim: Dimension of latent representation.
|
||||||
|
target_entropy: Target entropy for automatic temperature tuning.
|
||||||
|
use_backup_entropy: Whether to use backup entropy.
|
||||||
|
grad_clip_norm: Gradient clipping norm.
|
||||||
|
critic_network_kwargs: Additional arguments for critic networks.
|
||||||
|
actor_network_kwargs: Additional arguments for actor network.
|
||||||
|
policy_kwargs: Additional arguments for policy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Input / output structure
|
||||||
|
n_obs_steps: int = 1
|
||||||
|
|
||||||
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": [3, 84, 84],
|
"VISUAL": NormalizationMode.MEAN_STD,
|
||||||
"observation.state": [4],
|
"STATE": NormalizationMode.MIN_MAX,
|
||||||
|
"ENV": NormalizationMode.MIN_MAX,
|
||||||
|
"ACTION": NormalizationMode.MIN_MAX,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
output_shapes: dict[str, list[int]] = field(
|
dataset_stats: dict[str, dict[str, list[float]]] = field(
|
||||||
default_factory=lambda: {
|
|
||||||
"action": [2],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
input_normalization_modes: dict[str, str] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"observation.image": "mean_std",
|
|
||||||
"observation.state": "min_max",
|
|
||||||
"observation.environment_state": "min_max",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
input_normalization_params: dict[str, dict[str, list[float]]] = field(
|
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": {
|
"observation.image": {
|
||||||
"mean": [[0.485, 0.456, 0.406]],
|
"mean": [0.485, 0.456, 0.406],
|
||||||
"std": [[0.229, 0.224, 0.225]],
|
"std": [0.229, 0.224, 0.225],
|
||||||
|
},
|
||||||
|
"observation.state": {
|
||||||
|
"min": [0.0, 0.0],
|
||||||
|
"max": [1.0, 1.0],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"min": [0.0, 0.0, 0.0],
|
||||||
|
"max": [1.0, 1.0, 1.0],
|
||||||
},
|
},
|
||||||
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
|
||||||
output_normalization_params: dict[str, dict[str, list[float]]] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"action": {"min": [-1, -1], "max": [1, 1]},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
# TODO: Move it outside of the config
|
|
||||||
actor_learner_config: dict[str, str | int] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"learner_host": "127.0.0.1",
|
|
||||||
"learner_port": 50051,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
camera_number: int = 1
|
|
||||||
|
|
||||||
|
# Architecture specifics
|
||||||
|
camera_number: int = 1
|
||||||
storage_device: str = "cpu"
|
storage_device: str = "cpu"
|
||||||
# Add type annotations for these fields:
|
# Set to "helper2424/resnet10" for hil serl
|
||||||
vision_encoder_name: str | None = field(default="helper2424/resnet10")
|
vision_encoder_name: str | None = None
|
||||||
freeze_vision_encoder: bool = True
|
freeze_vision_encoder: bool = True
|
||||||
image_encoder_hidden_dim: int = 32
|
image_encoder_hidden_dim: int = 32
|
||||||
shared_encoder: bool = True
|
shared_encoder: bool = True
|
||||||
|
|
||||||
|
# SAC algorithm parameters
|
||||||
discount: float = 0.99
|
discount: float = 0.99
|
||||||
temperature_init: float = 1.0
|
temperature_init: float = 1.0
|
||||||
num_critics: int = 2
|
num_critics: int = 2
|
||||||
|
@ -83,6 +108,8 @@ class SACConfig:
|
||||||
target_entropy: float | None = None
|
target_entropy: float | None = None
|
||||||
use_backup_entropy: bool = True
|
use_backup_entropy: bool = True
|
||||||
grad_clip_norm: float = 40.0
|
grad_clip_norm: float = 40.0
|
||||||
|
|
||||||
|
# Network configuration
|
||||||
critic_network_kwargs: dict[str, Any] = field(
|
critic_network_kwargs: dict[str, Any] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"hidden_dims": [256, 256],
|
"hidden_dims": [256, 256],
|
||||||
|
@ -104,3 +131,52 @@ class SACConfig:
|
||||||
"init_final": 0.05,
|
"init_final": 0.05,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Deprecated, kept for backward compatibility
|
||||||
|
actor_learner_config: dict[str, str | int] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"learner_host": "127.0.0.1",
|
||||||
|
"learner_port": 50051,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
# Any validation specific to SAC configuration
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||||
|
return MultiAdamConfig(
|
||||||
|
weight_decay=0.0,
|
||||||
|
optimizer_groups={
|
||||||
|
"actor": {"lr": self.actor_lr},
|
||||||
|
"critic": {"lr": self.critic_lr},
|
||||||
|
"temperature": {"lr": self.temperature_lr},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scheduler_preset(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
# TODO: Maybe we should remove this raise?
|
||||||
|
if len(self.image_features) == 0:
|
||||||
|
raise ValueError("You must provide at least one image among the inputs.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> list:
|
||||||
|
return list(range(1 - self.n_obs_steps, 1))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list:
|
||||||
|
return [0] # SAC typically predicts one action at a time
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import draccus
|
||||||
|
config = SACConfig()
|
||||||
|
draccus.set_config_type("json")
|
||||||
|
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
|
||||||
|
|
||||||
|
|
|
@ -29,18 +29,17 @@ import torch.nn.functional as F # noqa: N812
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||||
from lerobot.common.policies.utils import get_device_from_parameters
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
|
|
||||||
|
|
||||||
class SACPolicy(
|
class SACPolicy(
|
||||||
nn.Module,
|
PreTrainedPolicy,
|
||||||
PyTorchModelHubMixin,
|
|
||||||
library_name="lerobot",
|
|
||||||
repo_url="https://github.com/huggingface/lerobot",
|
|
||||||
tags=["robotics", "RL", "SAC"],
|
|
||||||
):
|
):
|
||||||
|
|
||||||
|
config_class = SACConfig
|
||||||
name = "sac"
|
name = "sac"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -48,35 +47,33 @@ class SACPolicy(
|
||||||
config: SACConfig | None = None,
|
config: SACConfig | None = None,
|
||||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
|
config.validate_features()
|
||||||
if config is None:
|
|
||||||
config = SACConfig()
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
if config.input_normalization_modes is not None:
|
if config.dataset_stats is not None:
|
||||||
input_normalization_params = _convert_normalization_params_to_tensor(
|
input_normalization_params = _convert_normalization_params_to_tensor(
|
||||||
config.input_normalization_params
|
config.dataset_stats
|
||||||
)
|
)
|
||||||
self.normalize_inputs = Normalize(
|
self.normalize_inputs = Normalize(
|
||||||
config.input_shapes,
|
config.input_features,
|
||||||
config.input_normalization_modes,
|
config.normalization_mapping,
|
||||||
input_normalization_params,
|
input_normalization_params,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.normalize_inputs = nn.Identity()
|
self.normalize_inputs = nn.Identity()
|
||||||
|
|
||||||
output_normalization_params = _convert_normalization_params_to_tensor(
|
output_normalization_params = _convert_normalization_params_to_tensor(
|
||||||
config.output_normalization_params
|
config.dataset_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
# HACK: This is hacky and should be removed
|
# HACK: This is hacky and should be removed
|
||||||
dataset_stats = dataset_stats or output_normalization_params
|
dataset_stats = dataset_stats or output_normalization_params
|
||||||
self.normalize_targets = Normalize(
|
self.normalize_targets = Normalize(
|
||||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
config.output_features, config.normalization_mapping, dataset_stats
|
||||||
)
|
)
|
||||||
self.unnormalize_outputs = Unnormalize(
|
self.unnormalize_outputs = Unnormalize(
|
||||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
config.output_features, config.normalization_mapping, dataset_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: For images the encoder should be shared between the actor and critic
|
# NOTE: For images the encoder should be shared between the actor and critic
|
||||||
|
@ -90,7 +87,7 @@ class SACPolicy(
|
||||||
# Create a list of critic heads
|
# Create a list of critic heads
|
||||||
critic_heads = [
|
critic_heads = [
|
||||||
CriticHead(
|
CriticHead(
|
||||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
|
||||||
**config.critic_network_kwargs,
|
**config.critic_network_kwargs,
|
||||||
)
|
)
|
||||||
for _ in range(config.num_critics)
|
for _ in range(config.num_critics)
|
||||||
|
@ -105,7 +102,7 @@ class SACPolicy(
|
||||||
# Create target critic heads as deepcopies of the original critic heads
|
# Create target critic heads as deepcopies of the original critic heads
|
||||||
target_critic_heads = [
|
target_critic_heads = [
|
||||||
CriticHead(
|
CriticHead(
|
||||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
|
||||||
**config.critic_network_kwargs,
|
**config.critic_network_kwargs,
|
||||||
)
|
)
|
||||||
for _ in range(config.num_critics)
|
for _ in range(config.num_critics)
|
||||||
|
@ -125,12 +122,12 @@ class SACPolicy(
|
||||||
self.actor = Policy(
|
self.actor = Policy(
|
||||||
encoder=encoder_actor,
|
encoder=encoder_actor,
|
||||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||||
action_dim=config.output_shapes["action"][0],
|
action_dim=config.output_features["action"].shape[0],
|
||||||
encoder_is_shared=config.shared_encoder,
|
encoder_is_shared=config.shared_encoder,
|
||||||
**config.policy_kwargs,
|
**config.policy_kwargs,
|
||||||
)
|
)
|
||||||
if config.target_entropy is None:
|
if config.target_entropy is None:
|
||||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
|
||||||
|
|
||||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||||
|
@ -140,102 +137,13 @@ class SACPolicy(
|
||||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||||
self.temperature = self.log_alpha.exp().item()
|
self.temperature = self.log_alpha.exp().item()
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory):
|
|
||||||
"""Custom save method to handle TensorDict properly"""
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from dataclasses import asdict
|
|
||||||
|
|
||||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
def get_optim_params(self) -> dict:
|
||||||
from safetensors.torch import save_model
|
return {
|
||||||
|
"actor": self.actor.parameters_to_optimize,
|
||||||
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
|
"critic": self.critic_ensemble.parameters_to_optimize,
|
||||||
|
"temperature": self.log_alpha,
|
||||||
# Save config
|
}
|
||||||
config_dict = asdict(self.config)
|
|
||||||
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
|
|
||||||
json.dump(config_dict, f, indent=2)
|
|
||||||
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _from_pretrained(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str],
|
|
||||||
cache_dir: Optional[Union[str, Path]],
|
|
||||||
force_download: bool,
|
|
||||||
proxies: Optional[Dict],
|
|
||||||
resume_download: Optional[bool],
|
|
||||||
local_files_only: bool,
|
|
||||||
token: Optional[Union[str, bool]],
|
|
||||||
map_location: str = "cpu",
|
|
||||||
strict: bool = False,
|
|
||||||
**model_kwargs,
|
|
||||||
) -> "SACPolicy":
|
|
||||||
"""Custom load method to handle loading SAC policy from saved files"""
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
|
||||||
from safetensors.torch import load_model
|
|
||||||
|
|
||||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
|
||||||
|
|
||||||
# Check if model_id is a local path or a hub model ID
|
|
||||||
if os.path.isdir(model_id):
|
|
||||||
model_path = Path(model_id)
|
|
||||||
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
|
|
||||||
config_file = os.path.join(model_path, CONFIG_NAME)
|
|
||||||
else:
|
|
||||||
# Download the safetensors file from the hub
|
|
||||||
safetensors_file = hf_hub_download(
|
|
||||||
repo_id=model_id,
|
|
||||||
filename=SAFETENSORS_SINGLE_FILE,
|
|
||||||
revision=revision,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
token=token,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
# Download the config file
|
|
||||||
try:
|
|
||||||
config_file = hf_hub_download(
|
|
||||||
repo_id=model_id,
|
|
||||||
filename=CONFIG_NAME,
|
|
||||||
revision=revision,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
token=token,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
config_file = None
|
|
||||||
|
|
||||||
# Load or create config
|
|
||||||
if config_file and os.path.exists(config_file):
|
|
||||||
# Load config from file
|
|
||||||
with open(config_file) as f:
|
|
||||||
config_dict = json.load(f)
|
|
||||||
config = SACConfig(**config_dict)
|
|
||||||
else:
|
|
||||||
# Use the provided config or create a default one
|
|
||||||
config = model_kwargs.get("config", SACConfig())
|
|
||||||
|
|
||||||
# Create a new instance with the loaded config
|
|
||||||
model = cls(config=config)
|
|
||||||
|
|
||||||
# Load state dict from safetensors file
|
|
||||||
if os.path.exists(safetensors_file):
|
|
||||||
load_model(model, filename=safetensors_file, device=map_location)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the policy"""
|
"""Reset the policy"""
|
||||||
|
@ -667,7 +575,7 @@ class SACObservationEncoder(nn.Module):
|
||||||
self.parameters_to_optimize = []
|
self.parameters_to_optimize = []
|
||||||
|
|
||||||
self.aggregation_size: int = 0
|
self.aggregation_size: int = 0
|
||||||
if any("observation.image" in key for key in config.input_shapes):
|
if any("observation.image" in key for key in config.input_features):
|
||||||
self.camera_number = config.camera_number
|
self.camera_number = config.camera_number
|
||||||
|
|
||||||
if self.config.vision_encoder_name is not None:
|
if self.config.vision_encoder_name is not None:
|
||||||
|
@ -682,12 +590,12 @@ class SACObservationEncoder(nn.Module):
|
||||||
freeze_image_encoder(self.image_enc_layers)
|
freeze_image_encoder(self.image_enc_layers)
|
||||||
else:
|
else:
|
||||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||||
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||||
|
|
||||||
if "observation.state" in config.input_shapes:
|
if "observation.state" in config.input_features:
|
||||||
self.state_enc_layers = nn.Sequential(
|
self.state_enc_layers = nn.Sequential(
|
||||||
nn.Linear(
|
nn.Linear(
|
||||||
in_features=config.input_shapes["observation.state"][0],
|
in_features=config.input_features["observation.state"].shape[0],
|
||||||
out_features=config.latent_dim,
|
out_features=config.latent_dim,
|
||||||
),
|
),
|
||||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||||
|
@ -697,10 +605,10 @@ class SACObservationEncoder(nn.Module):
|
||||||
|
|
||||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||||
|
|
||||||
if "observation.environment_state" in config.input_shapes:
|
if "observation.environment_state" in config.input_features:
|
||||||
self.env_state_enc_layers = nn.Sequential(
|
self.env_state_enc_layers = nn.Sequential(
|
||||||
nn.Linear(
|
nn.Linear(
|
||||||
in_features=config.input_shapes["observation.environment_state"][0],
|
in_features=config.input_features["observation.environment_state"].shape[0],
|
||||||
out_features=config.latent_dim,
|
out_features=config.latent_dim,
|
||||||
),
|
),
|
||||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||||
|
@ -727,9 +635,9 @@ class SACObservationEncoder(nn.Module):
|
||||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||||
feat.extend(embeddings_chunks)
|
feat.extend(embeddings_chunks)
|
||||||
|
|
||||||
if "observation.environment_state" in self.config.input_shapes:
|
if "observation.environment_state" in self.config.input_features:
|
||||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||||
if "observation.state" in self.config.input_shapes:
|
if "observation.state" in self.config.input_features:
|
||||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||||
|
|
||||||
features = torch.cat(tensors=feat, dim=-1)
|
features = torch.cat(tensors=feat, dim=-1)
|
||||||
|
@ -744,11 +652,11 @@ class SACObservationEncoder(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class DefaultImageEncoder(nn.Module):
|
class DefaultImageEncoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config: SACConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.image_enc_layers = nn.Sequential(
|
self.image_enc_layers = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
in_channels=config.input_shapes["observation.image"][0],
|
in_channels=config.input_features["observation.image"].shape[0],
|
||||||
out_channels=config.image_encoder_hidden_dim,
|
out_channels=config.image_encoder_hidden_dim,
|
||||||
kernel_size=7,
|
kernel_size=7,
|
||||||
stride=2,
|
stride=2,
|
||||||
|
@ -776,7 +684,7 @@ class DefaultImageEncoder(nn.Module):
|
||||||
),
|
),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
)
|
)
|
||||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||||
self.image_enc_layers.extend(
|
self.image_enc_layers.extend(
|
||||||
|
@ -793,7 +701,7 @@ class DefaultImageEncoder(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class PretrainedImageEncoder(nn.Module):
|
class PretrainedImageEncoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config: SACConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||||
|
@ -803,7 +711,7 @@ class PretrainedImageEncoder(nn.Module):
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_pretrained_vision_encoder(self, config):
|
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||||
"""Set up CNN encoder"""
|
"""Set up CNN encoder"""
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
@ -857,73 +765,73 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Benchmark the CriticEnsemble performance
|
# # Benchmark the CriticEnsemble performance
|
||||||
import time
|
# import time
|
||||||
|
|
||||||
# Configuration
|
# # Configuration
|
||||||
num_critics = 10
|
# num_critics = 10
|
||||||
batch_size = 32
|
# batch_size = 32
|
||||||
action_dim = 7
|
# action_dim = 7
|
||||||
obs_dim = 64
|
# obs_dim = 64
|
||||||
hidden_dims = [256, 256]
|
# hidden_dims = [256, 256]
|
||||||
num_iterations = 100
|
# num_iterations = 100
|
||||||
|
|
||||||
print("Creating test environment...")
|
# print("Creating test environment...")
|
||||||
|
|
||||||
# Create a simple dummy encoder
|
# # Create a simple dummy encoder
|
||||||
class DummyEncoder(nn.Module):
|
# class DummyEncoder(nn.Module):
|
||||||
def __init__(self):
|
# def __init__(self):
|
||||||
super().__init__()
|
# super().__init__()
|
||||||
self.output_dim = obs_dim
|
# self.output_dim = obs_dim
|
||||||
self.parameters_to_optimize = []
|
# self.parameters_to_optimize = []
|
||||||
|
|
||||||
def forward(self, obs):
|
# def forward(self, obs):
|
||||||
# Just return a random tensor of the right shape
|
# # Just return a random tensor of the right shape
|
||||||
# In practice, this would encode the observations
|
# # In practice, this would encode the observations
|
||||||
return torch.randn(batch_size, obs_dim, device=device)
|
# return torch.randn(batch_size, obs_dim, device=device)
|
||||||
|
|
||||||
# Create critic heads
|
# # Create critic heads
|
||||||
print(f"Creating {num_critics} critic heads...")
|
# print(f"Creating {num_critics} critic heads...")
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
critic_heads = [
|
# critic_heads = [
|
||||||
CriticHead(
|
# CriticHead(
|
||||||
input_dim=obs_dim + action_dim,
|
# input_dim=obs_dim + action_dim,
|
||||||
hidden_dims=hidden_dims,
|
# hidden_dims=hidden_dims,
|
||||||
).to(device)
|
# ).to(device)
|
||||||
for _ in range(num_critics)
|
# for _ in range(num_critics)
|
||||||
]
|
# ]
|
||||||
|
|
||||||
# Create the critic ensemble
|
# # Create the critic ensemble
|
||||||
print("Creating CriticEnsemble...")
|
# print("Creating CriticEnsemble...")
|
||||||
critic_ensemble = CriticEnsemble(
|
# critic_ensemble = CriticEnsemble(
|
||||||
encoder=DummyEncoder().to(device),
|
# encoder=DummyEncoder().to(device),
|
||||||
ensemble=critic_heads,
|
# ensemble=critic_heads,
|
||||||
output_normalization=nn.Identity(),
|
# output_normalization=nn.Identity(),
|
||||||
).to(device)
|
# ).to(device)
|
||||||
|
|
||||||
# Create random input data
|
# # Create random input data
|
||||||
print("Creating input data...")
|
# print("Creating input data...")
|
||||||
obs_dict = {
|
# obs_dict = {
|
||||||
"observation.state": torch.randn(batch_size, obs_dim, device=device),
|
# "observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||||
}
|
# }
|
||||||
actions = torch.randn(batch_size, action_dim, device=device)
|
# actions = torch.randn(batch_size, action_dim, device=device)
|
||||||
|
|
||||||
# Warmup run
|
# # Warmup run
|
||||||
print("Warming up...")
|
# print("Warming up...")
|
||||||
_ = critic_ensemble(obs_dict, actions)
|
# _ = critic_ensemble(obs_dict, actions)
|
||||||
|
|
||||||
# Time the forward pass
|
# # Time the forward pass
|
||||||
print(f"Running benchmark with {num_iterations} iterations...")
|
# print(f"Running benchmark with {num_iterations} iterations...")
|
||||||
start_time = time.perf_counter()
|
# start_time = time.perf_counter()
|
||||||
for _ in range(num_iterations):
|
# for _ in range(num_iterations):
|
||||||
q_values = critic_ensemble(obs_dict, actions)
|
# q_values = critic_ensemble(obs_dict, actions)
|
||||||
end_time = time.perf_counter()
|
# end_time = time.perf_counter()
|
||||||
|
|
||||||
# Print results
|
# # Print results
|
||||||
elapsed_time = end_time - start_time
|
# elapsed_time = end_time - start_time
|
||||||
print(f"Total time: {elapsed_time:.4f} seconds")
|
# print(f"Total time: {elapsed_time:.4f} seconds")
|
||||||
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
# print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||||
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
# print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||||
|
|
||||||
# Verify that all critic heads produce different outputs
|
# Verify that all critic heads produce different outputs
|
||||||
# This confirms each critic head is unique
|
# This confirms each critic head is unique
|
||||||
|
@ -932,3 +840,11 @@ if __name__ == "__main__":
|
||||||
# for j in range(i + 1, num_critics):
|
# for j in range(i + 1, num_critics):
|
||||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||||
|
import draccus
|
||||||
|
|
||||||
|
from lerobot.configs import parser
|
||||||
|
@parser.wrap()
|
||||||
|
def main(config: SACConfig):
|
||||||
|
policy = SACPolicy(config=config)
|
||||||
|
print("yolo")
|
||||||
|
main()
|
Loading…
Reference in New Issue