Handle new config with sac
This commit is contained in:
parent
b2025b852c
commit
f483931fc0
|
@ -55,10 +55,9 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
|||
|
||||
return PI0Policy
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy, SACConfig
|
||||
return SACPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
|
|
@ -79,28 +79,38 @@ def create_stats_buffers(
|
|||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if stats and key in stats:
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||
raise ValueError(f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization")
|
||||
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead.")
|
||||
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" not in stats[key] or "max" not in stats[key]:
|
||||
raise ValueError(f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization")
|
||||
|
||||
if isinstance(stats[key]["min"], np.ndarray):
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["min"], torch.Tensor):
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
else:
|
||||
type_ = type(stats[key]["min"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
@ -155,6 +165,7 @@ class Normalize(nn.Module):
|
|||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
|
|
|
@ -18,57 +18,82 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Configuration class for Soft Actor-Critic (SAC) policy.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
|
||||
normalization_mapping: Mapping from feature types to normalization modes.
|
||||
camera_number: Number of cameras to use.
|
||||
storage_device: Device to use for storage.
|
||||
vision_encoder_name: Name of the vision encoder to use.
|
||||
freeze_vision_encoder: Whether to freeze the vision encoder.
|
||||
image_encoder_hidden_dim: Hidden dimension for the image encoder.
|
||||
shared_encoder: Whether to use a shared encoder.
|
||||
discount: Discount factor for the RL algorithm.
|
||||
temperature_init: Initial temperature for entropy regularization.
|
||||
num_critics: Number of critic networks.
|
||||
num_subsample_critics: Number of critics to subsample.
|
||||
critic_lr: Learning rate for critic networks.
|
||||
actor_lr: Learning rate for actor network.
|
||||
temperature_lr: Learning rate for temperature parameter.
|
||||
critic_target_update_weight: Weight for soft target updates.
|
||||
utd_ratio: Update-to-data ratio (>1 to enable).
|
||||
state_encoder_hidden_dim: Hidden dimension for state encoder.
|
||||
latent_dim: Dimension of latent representation.
|
||||
target_entropy: Target entropy for automatic temperature tuning.
|
||||
use_backup_entropy: Whether to use backup entropy.
|
||||
grad_clip_norm: Gradient clipping norm.
|
||||
critic_network_kwargs: Additional arguments for critic networks.
|
||||
actor_network_kwargs: Additional arguments for actor network.
|
||||
policy_kwargs: Additional arguments for policy.
|
||||
"""
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
"observation.environment_state": "min_max",
|
||||
}
|
||||
)
|
||||
input_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||
dataset_stats: dict[str, dict[str, list[float]]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {
|
||||
"mean": [[0.485, 0.456, 0.406]],
|
||||
"std": [[0.229, 0.224, 0.225]],
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
output_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": {"min": [-1, -1], "max": [1, 1]},
|
||||
}
|
||||
)
|
||||
# TODO: Move it outside of the config
|
||||
actor_learner_config: dict[str, str | int] = field(
|
||||
default_factory=lambda: {
|
||||
"learner_host": "127.0.0.1",
|
||||
"learner_port": 50051,
|
||||
}
|
||||
)
|
||||
camera_number: int = 1
|
||||
|
||||
# Architecture specifics
|
||||
camera_number: int = 1
|
||||
storage_device: str = "cpu"
|
||||
# Add type annotations for these fields:
|
||||
vision_encoder_name: str | None = field(default="helper2424/resnet10")
|
||||
# Set to "helper2424/resnet10" for hil serl
|
||||
vision_encoder_name: str | None = None
|
||||
freeze_vision_encoder: bool = True
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
|
||||
# SAC algorithm parameters
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
num_critics: int = 2
|
||||
|
@ -83,6 +108,8 @@ class SACConfig:
|
|||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
critic_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
|
@ -104,3 +131,52 @@ class SACConfig:
|
|||
"init_final": 0.05,
|
||||
}
|
||||
)
|
||||
|
||||
# Deprecated, kept for backward compatibility
|
||||
actor_learner_config: dict[str, str | int] = field(
|
||||
default_factory=lambda: {
|
||||
"learner_host": "127.0.0.1",
|
||||
"learner_port": 50051,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: Maybe we should remove this raise?
|
||||
if len(self.image_features) == 0:
|
||||
raise ValueError("You must provide at least one image among the inputs.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return [0] # SAC typically predicts one action at a time
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
import draccus
|
||||
config = SACConfig()
|
||||
draccus.set_config_type("json")
|
||||
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
|
||||
|
||||
|
|
|
@ -29,18 +29,17 @@ import torch.nn.functional as F # noqa: N812
|
|||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
|
@ -48,35 +47,33 @@ class SACPolicy(
|
|||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = SACConfig()
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
if config.dataset_stats is not None:
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(
|
||||
config.input_normalization_params
|
||||
config.dataset_stats
|
||||
)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes,
|
||||
config.input_normalization_modes,
|
||||
config.input_features,
|
||||
config.normalization_mapping,
|
||||
input_normalization_params,
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(
|
||||
config.output_normalization_params
|
||||
config.dataset_stats
|
||||
)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# NOTE: For images the encoder should be shared between the actor and critic
|
||||
|
@ -90,7 +87,7 @@ class SACPolicy(
|
|||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
|
@ -105,7 +102,7 @@ class SACPolicy(
|
|||
# Create target critic heads as deepcopies of the original critic heads
|
||||
target_critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
|
@ -125,12 +122,12 @@ class SACPolicy(
|
|||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
action_dim=config.output_features["action"].shape[0],
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
|
@ -140,102 +137,13 @@ class SACPolicy(
|
|||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
"""Custom save method to handle TensorDict properly"""
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
|
||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
||||
from safetensors.torch import save_model
|
||||
|
||||
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
# Save config
|
||||
config_dict = asdict(self.config)
|
||||
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
|
||||
|
||||
@classmethod
|
||||
def _from_pretrained(
|
||||
cls,
|
||||
*,
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
cache_dir: Optional[Union[str, Path]],
|
||||
force_download: bool,
|
||||
proxies: Optional[Dict],
|
||||
resume_download: Optional[bool],
|
||||
local_files_only: bool,
|
||||
token: Optional[Union[str, bool]],
|
||||
map_location: str = "cpu",
|
||||
strict: bool = False,
|
||||
**model_kwargs,
|
||||
) -> "SACPolicy":
|
||||
"""Custom load method to handle loading SAC policy from saved files"""
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
||||
from safetensors.torch import load_model
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
|
||||
# Check if model_id is a local path or a hub model ID
|
||||
if os.path.isdir(model_id):
|
||||
model_path = Path(model_id)
|
||||
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
|
||||
config_file = os.path.join(model_path, CONFIG_NAME)
|
||||
else:
|
||||
# Download the safetensors file from the hub
|
||||
safetensors_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
# Download the config file
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except Exception:
|
||||
config_file = None
|
||||
|
||||
# Load or create config
|
||||
if config_file and os.path.exists(config_file):
|
||||
# Load config from file
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
config = SACConfig(**config_dict)
|
||||
else:
|
||||
# Use the provided config or create a default one
|
||||
config = model_kwargs.get("config", SACConfig())
|
||||
|
||||
# Create a new instance with the loaded config
|
||||
model = cls(config=config)
|
||||
|
||||
# Load state dict from safetensors file
|
||||
if os.path.exists(safetensors_file):
|
||||
load_model(model, filename=safetensors_file, device=map_location)
|
||||
|
||||
return model
|
||||
def get_optim_params(self) -> dict:
|
||||
return {
|
||||
"actor": self.actor.parameters_to_optimize,
|
||||
"critic": self.critic_ensemble.parameters_to_optimize,
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the policy"""
|
||||
|
@ -667,7 +575,7 @@ class SACObservationEncoder(nn.Module):
|
|||
self.parameters_to_optimize = []
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
if any("observation.image" in key for key in config.input_shapes):
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
self.camera_number = config.camera_number
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
|
@ -682,12 +590,12 @@ class SACObservationEncoder(nn.Module):
|
|||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
|
||||
if "observation.state" in config.input_shapes:
|
||||
if "observation.state" in config.input_features:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.state"][0],
|
||||
in_features=config.input_features["observation.state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
|
@ -697,10 +605,10 @@ class SACObservationEncoder(nn.Module):
|
|||
|
||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.environment_state"][0],
|
||||
in_features=config.input_features["observation.environment_state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
|
@ -727,9 +635,9 @@ class SACObservationEncoder(nn.Module):
|
|||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
feat.extend(embeddings_chunks)
|
||||
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
if "observation.environment_state" in self.config.input_features:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
if "observation.state" in self.config.input_features:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
|
@ -744,11 +652,11 @@ class SACObservationEncoder(nn.Module):
|
|||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=config.input_shapes["observation.image"][0],
|
||||
in_channels=config.input_features["observation.image"].shape[0],
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
|
@ -776,7 +684,7 @@ class DefaultImageEncoder(nn.Module):
|
|||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
|
@ -793,7 +701,7 @@ class DefaultImageEncoder(nn.Module):
|
|||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
|
@ -803,7 +711,7 @@ class PretrainedImageEncoder(nn.Module):
|
|||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config):
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
|
@ -857,73 +765,73 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Benchmark the CriticEnsemble performance
|
||||
import time
|
||||
# # Benchmark the CriticEnsemble performance
|
||||
# import time
|
||||
|
||||
# Configuration
|
||||
num_critics = 10
|
||||
batch_size = 32
|
||||
action_dim = 7
|
||||
obs_dim = 64
|
||||
hidden_dims = [256, 256]
|
||||
num_iterations = 100
|
||||
# # Configuration
|
||||
# num_critics = 10
|
||||
# batch_size = 32
|
||||
# action_dim = 7
|
||||
# obs_dim = 64
|
||||
# hidden_dims = [256, 256]
|
||||
# num_iterations = 100
|
||||
|
||||
print("Creating test environment...")
|
||||
# print("Creating test environment...")
|
||||
|
||||
# Create a simple dummy encoder
|
||||
class DummyEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.output_dim = obs_dim
|
||||
self.parameters_to_optimize = []
|
||||
# # Create a simple dummy encoder
|
||||
# class DummyEncoder(nn.Module):
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.output_dim = obs_dim
|
||||
# self.parameters_to_optimize = []
|
||||
|
||||
def forward(self, obs):
|
||||
# Just return a random tensor of the right shape
|
||||
# In practice, this would encode the observations
|
||||
return torch.randn(batch_size, obs_dim, device=device)
|
||||
# def forward(self, obs):
|
||||
# # Just return a random tensor of the right shape
|
||||
# # In practice, this would encode the observations
|
||||
# return torch.randn(batch_size, obs_dim, device=device)
|
||||
|
||||
# Create critic heads
|
||||
print(f"Creating {num_critics} critic heads...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=obs_dim + action_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
).to(device)
|
||||
for _ in range(num_critics)
|
||||
]
|
||||
# # Create critic heads
|
||||
# print(f"Creating {num_critics} critic heads...")
|
||||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# critic_heads = [
|
||||
# CriticHead(
|
||||
# input_dim=obs_dim + action_dim,
|
||||
# hidden_dims=hidden_dims,
|
||||
# ).to(device)
|
||||
# for _ in range(num_critics)
|
||||
# ]
|
||||
|
||||
# Create the critic ensemble
|
||||
print("Creating CriticEnsemble...")
|
||||
critic_ensemble = CriticEnsemble(
|
||||
encoder=DummyEncoder().to(device),
|
||||
ensemble=critic_heads,
|
||||
output_normalization=nn.Identity(),
|
||||
).to(device)
|
||||
# # Create the critic ensemble
|
||||
# print("Creating CriticEnsemble...")
|
||||
# critic_ensemble = CriticEnsemble(
|
||||
# encoder=DummyEncoder().to(device),
|
||||
# ensemble=critic_heads,
|
||||
# output_normalization=nn.Identity(),
|
||||
# ).to(device)
|
||||
|
||||
# Create random input data
|
||||
print("Creating input data...")
|
||||
obs_dict = {
|
||||
"observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||
}
|
||||
actions = torch.randn(batch_size, action_dim, device=device)
|
||||
# # Create random input data
|
||||
# print("Creating input data...")
|
||||
# obs_dict = {
|
||||
# "observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||
# }
|
||||
# actions = torch.randn(batch_size, action_dim, device=device)
|
||||
|
||||
# Warmup run
|
||||
print("Warming up...")
|
||||
_ = critic_ensemble(obs_dict, actions)
|
||||
# # Warmup run
|
||||
# print("Warming up...")
|
||||
# _ = critic_ensemble(obs_dict, actions)
|
||||
|
||||
# Time the forward pass
|
||||
print(f"Running benchmark with {num_iterations} iterations...")
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(num_iterations):
|
||||
q_values = critic_ensemble(obs_dict, actions)
|
||||
end_time = time.perf_counter()
|
||||
# # Time the forward pass
|
||||
# print(f"Running benchmark with {num_iterations} iterations...")
|
||||
# start_time = time.perf_counter()
|
||||
# for _ in range(num_iterations):
|
||||
# q_values = critic_ensemble(obs_dict, actions)
|
||||
# end_time = time.perf_counter()
|
||||
|
||||
# Print results
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"Total time: {elapsed_time:.4f} seconds")
|
||||
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||
# # Print results
|
||||
# elapsed_time = end_time - start_time
|
||||
# print(f"Total time: {elapsed_time:.4f} seconds")
|
||||
# print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||
# print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||
|
||||
# Verify that all critic heads produce different outputs
|
||||
# This confirms each critic head is unique
|
||||
|
@ -932,3 +840,11 @@ if __name__ == "__main__":
|
|||
# for j in range(i + 1, num_critics):
|
||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import parser
|
||||
@parser.wrap()
|
||||
def main(config: SACConfig):
|
||||
policy = SACPolicy(config=config)
|
||||
print("yolo")
|
||||
main()
|
Loading…
Reference in New Issue