Adil Zouitine 2025-04-14 16:01:24 +02:00 committed by GitHub
commit 8952f5fd43
9 changed files with 986 additions and 755 deletions

View File

@ -171,7 +171,6 @@ class VideoRecordConfig:
class WrapperConfig:
"""Configuration for environment wrappers."""
delta_action: float | None = None
joint_masking_action_space: list[bool] | None = None
@ -191,7 +190,6 @@ class EnvWrapperConfig:
"""Configuration for environment wrappers."""
display_cameras: bool = False
delta_action: float = 0.1
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_ee_pose_to_observation: bool = False
@ -203,6 +201,10 @@ class EnvWrapperConfig:
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@ -254,6 +256,7 @@ class ManiskillEnvConfig(EnvConfig):
robot: str = "so100" # This is a hack to make the robot config work
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
mock_gripper: bool = False
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),

View File

@ -85,12 +85,14 @@ class SACConfig(PreTrainedConfig):
freeze_vision_encoder: Whether to freeze the vision encoder during training.
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
num_discrete_actions: Number of discrete actions, e.g. for gripper actions.
concurrency: Configuration for concurrency settings.
actor_learner: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
offline_buffer_capacity: Capacity of the offline replay buffer.
async_prefetch: Whether to use asynchronous prefetching for the buffers.
online_step_before_learning: Number of steps before learning starts.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the SAC algorithm.
@ -118,7 +120,7 @@ class SACConfig(PreTrainedConfig):
}
)
dataset_stats: dict[str, dict[str, list[float]]] = field(
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
@ -144,12 +146,14 @@ class SACConfig(PreTrainedConfig):
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
num_discrete_actions: int | None = None
# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
async_prefetch: bool = False
online_step_before_learning: int = 100
policy_update_freq: int = 1
@ -173,7 +177,7 @@ class SACConfig(PreTrainedConfig):
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
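For orientation, a hedged sketch of how the new SACConfig fields might be set. The import path appears verbatim in the modeling file below; whether the config validates with only these keywords depends on parts of the dataclass not shown in these hunks:

from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig(
    num_discrete_actions=3,          # enables the grasp critic, e.g. open / stay / close gripper
    dataset_stats=None,              # now optional; None makes the policy fall back to identity normalization
    offline_buffer_capacity=100_000,
    async_prefetch=True,             # learner samples through asynchronous prefetch iterators
)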

View File

@ -33,6 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
PreTrainedPolicy,
@ -49,6 +51,8 @@ class SACPolicy(
config.validate_features()
self.config = config
continuous_action_dim = config.output_features["action"].shape[0]
if config.dataset_stats is not None:
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
self.normalize_inputs = Normalize(
@ -59,16 +63,20 @@ class SACPolicy(
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
if config.dataset_stats is not None:
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
else:
self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity()
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder:
@ -77,11 +85,12 @@ class SACPolicy(
else:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
self.shared_encoder = config.shared_encoder
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
@ -96,7 +105,7 @@ class SACPolicy(
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
@ -112,15 +121,41 @@ class SACPolicy(
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.grasp_critic = None
self.grasp_critic_target = None
if config.num_discrete_actions is not None:
# Create grasp critic
self.grasp_critic = GraspCritic(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
**asdict(config.grasp_critic_network_kwargs),
)
# Create target grasp critic
self.grasp_critic_target = GraspCritic(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
**asdict(config.grasp_critic_network_kwargs),
)
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
self.grasp_critic = torch.compile(self.grasp_critic)
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
action_dim=config.output_features["action"].shape[0],
action_dim=continuous_action_dim,
encoder_is_shared=config.shared_encoder,
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temperature parameter is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@ -131,11 +166,18 @@ class SACPolicy(
self.temperature = self.log_alpha.exp().item()
def get_optim_params(self) -> dict:
return {
"actor": self.actor.parameters_to_optimize,
"critic": self.critic_ensemble.parameters_to_optimize,
optim_params = {
"actor": [
p
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
return optim_params
def reset(self):
"""Reset the policy"""
@ -151,8 +193,19 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
actions, _, _ = self.actor(batch)
# We cache the encoder output to avoid recomputing it
observations_features = None
if self.shared_encoder:
observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.num_discrete_actions is not None:
discrete_action_value = self.grasp_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
actions = torch.cat([actions, discrete_action], dim=-1)
return actions
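When num_discrete_actions is set, select_action above appends one trailing column holding the argmax of the grasp critic's Q-values. A toy, self-contained illustration of that layout with dummy tensors (the explicit cast is mine, so the concatenation is well-typed):

import torch

batch_size = 2
continuous = torch.randn(batch_size, 6)             # continuous action head output
grasp_q = torch.randn(batch_size, 3)                # Q-values over the discrete gripper actions
discrete = torch.argmax(grasp_q, dim=-1, keepdim=True).float()  # cast to match the continuous dtype
action = torch.cat([continuous, discrete], dim=-1)  # shape (batch_size, 7); the gripper index is last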
def critic_forward(
@ -172,14 +225,30 @@ class SACPolicy(
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
"""Forward pass through a grasp critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the grasp critic network
"""
grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
q_values = grasp_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature"] = "critic",
model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
@ -192,12 +261,11 @@ class SACPolicy(
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", or "temperature")
model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
Returns:
The computed loss tensor
"""
# TODO: (maractingi, azouitine) Respect the function signature we output tensors
# Extract common components from batch
actions: Tensor = batch["action"]
observations: dict[str, Tensor] = batch["state"]
@ -210,7 +278,7 @@ class SACPolicy(
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
return self.compute_loss_critic(
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
@ -220,17 +288,41 @@ class SACPolicy(
next_observation_features=next_observation_features,
)
if model == "actor":
return self.compute_loss_actor(
return {"loss_critic": loss_critic}
if model == "grasp_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_grasp_critic = self.compute_loss_grasp_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_grasp_critic": loss_grasp_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
}
if model == "temperature":
return self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
return {
"loss_temperature": self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
}
raise ValueError(f"Unknown model type: {model}")
@ -245,6 +337,16 @@ class SACPolicy(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.grasp_critic_target.parameters(),
self.grasp_critic.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
@ -287,6 +389,11 @@ class SACPolicy(
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
@ -307,6 +414,65 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_grasp_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
gripper_penalties: Tensor | None = None
if complementary_info is not None:
gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_grasp_qs = self.grasp_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_grasp_qs = self.grasp_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_grasp_q = torch.gather(
target_next_grasp_qs, dim=1, index=best_next_grasp_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_gripper = rewards
if gripper_penalties is not None:
rewards_gripper = rewards + gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
# Get predicted Q-values for current observations
predicted_grasp_qs = self.grasp_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
return grasp_critic_loss
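To make the update above concrete: the next action is selected with the online grasp critic and evaluated with the target grasp critic, a double-DQN style target. A self-contained sketch of that argmax/gather pattern with dummy Q-values, not the author's code:

import torch
import torch.nn.functional as F

batch_size, num_discrete_actions, discount = 4, 3, 0.99
rewards = torch.zeros(batch_size)
done = torch.zeros(batch_size)
actions_discrete = torch.randint(0, num_discrete_actions, (batch_size, 1))

next_qs_online = torch.randn(batch_size, num_discrete_actions)   # stand-in for the online grasp critic
next_qs_target = torch.randn(batch_size, num_discrete_actions)   # stand-in for the target grasp critic
predicted_qs = torch.randn(batch_size, num_discrete_actions, requires_grad=True)

best_next = torch.argmax(next_qs_online, dim=-1, keepdim=True)                     # select with the online net
target_next_q = torch.gather(next_qs_target, dim=1, index=best_next).squeeze(-1)   # evaluate with the target net
target_q = rewards + (1 - done) * discount * target_next_q

predicted_q = torch.gather(predicted_qs, dim=1, index=actions_discrete).squeeze(-1)
loss = F.mse_loss(predicted_q, target_q)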
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
@ -337,6 +503,104 @@ class SACPolicy(
return actor_loss
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers.image_enc_layers)
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
def forward(
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,); the per-modality features are
concatenated and projected back to latent_dim by the aggregation layer.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
if vision_encoder_cache is not None:
feat.append(vision_encoder_cache)
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
# [N*B, C, H, W]
if normalize:
batch = self.input_normalization(batch)
if len(self.all_image_keys) > 0:
# Batch all images along the batch dimension, then encode them.
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
return embeddings_image
return None
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
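The new vision_encoder_cache argument lets callers compute image features once and reuse them across the actor, the critic ensemble, and the grasp critic (see select_action and get_observation_features in this diff). A stand-alone sketch of that contract with a toy encoder in place of SACObservationEncoder:

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Mimics the cache-aware interface: forward() accepts precomputed image features."""
    def __init__(self, obs_dim: int = 8, latent_dim: int = 4):
        super().__init__()
        self.backbone = nn.Linear(obs_dim, latent_dim)  # stands in for the expensive vision encoder
        self.head = nn.Linear(latent_dim, latent_dim)   # cheap fusion/aggregation part

    def get_image_features(self, obs: torch.Tensor) -> torch.Tensor:
        return self.backbone(obs)

    def forward(self, obs: torch.Tensor, vision_encoder_cache: torch.Tensor | None = None) -> torch.Tensor:
        if vision_encoder_cache is None:
            vision_encoder_cache = self.get_image_features(obs)
        return self.head(vision_encoder_cache)

encoder = ToyEncoder()
obs = torch.randn(2, 8)
cache = encoder.get_image_features(obs)                      # computed once
actor_features = encoder(obs, vision_encoder_cache=cache)    # reused by the actor
critic_features = encoder(obs, vision_encoder_cache=cache)   # reused by the critics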
class MLP(nn.Module):
def __init__(
self,
@ -459,7 +723,7 @@ class CriticEnsemble(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
encoder: SACObservationEncoder,
ensemble: List[CriticHead],
output_normalization: nn.Module,
init_final: Optional[float] = None,
@ -470,12 +734,6 @@ class CriticEnsemble(nn.Module):
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
self.parameters_to_optimize = []
# Handle the case where a part of the encoder is frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.critics.parameters())
def forward(
self,
observations: dict[str, torch.Tensor],
@ -491,11 +749,7 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
obs_enc = self.encoder(observations, observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
@ -509,10 +763,53 @@ class CriticEnsemble(nn.Module):
return q_values
class GraspCritic(nn.Module):
def __init__(
self,
encoder: nn.Module,
input_dim: int,
hidden_dims: list[int],
output_dim: int = 3,
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.encoder = encoder
self.output_dim = output_dim
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device by cloning first to avoid inplace operations
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
return self.output_layer(self.net(obs_enc))
class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
encoder: SACObservationEncoder,
network: nn.Module,
action_dim: int,
log_std_min: float = -5,
@ -523,19 +820,15 @@ class Policy(nn.Module):
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder = encoder
self.encoder: SACObservationEncoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std
self.use_tanh_squash = use_tanh_squash
self.parameters_to_optimize = []
self.encoder_is_shared = encoder_is_shared
self.parameters_to_optimize += list(self.network.parameters())
if self.encoder is not None and not encoder_is_shared:
self.parameters_to_optimize += list(self.encoder.parameters())
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
@ -549,7 +842,6 @@ class Policy(nn.Module):
else:
orthogonal_init()(self.mean_layer.weight)
self.parameters_to_optimize += list(self.mean_layer.parameters())
# Standard deviation layer or parameter
if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim)
@ -558,7 +850,6 @@ class Policy(nn.Module):
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.parameters_to_optimize += list(self.std_layer.parameters())
def forward(
self,
@ -566,11 +857,9 @@ class Policy(nn.Module):
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
if self.encoder_is_shared:
obs_enc = obs_enc.detach()
# Get network outputs
outputs = self.network(obs_enc)
@ -614,96 +903,6 @@ class Policy(nn.Module):
return observations
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()
@ -743,23 +942,25 @@ class DefaultImageEncoder(nn.Module):
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
self.image_enc_proj = nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
self.freeze_image_encoder = config.freeze_vision_encoder
def forward(self, x):
return self.image_enc_layers(x)
x = self.image_enc_layers(x)
if self.freeze_image_encoder:
x = x.detach()
return self.image_enc_proj(x)
class PretrainedImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
@ -767,6 +968,8 @@ class PretrainedImageEncoder(nn.Module):
nn.Tanh(),
)
self.freeze_image_encoder = config.freeze_vision_encoder
def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder"""
from transformers import AutoModel
@ -786,6 +989,8 @@ class PretrainedImageEncoder(nn.Module):
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
# doesn't reach the classifier layer because we don't need it
enc_feat = self.image_enc_layers(x).pooler_output
if self.freeze_image_encoder:
enc_feat = enc_feat.detach()
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
return enc_feat

View File

@ -221,7 +221,6 @@ def record_episode(
events=events,
policy=policy,
fps=fps,
# record_delta_actions=record_delta_actions,
teleoperate=policy is None,
single_task=single_task,
)
@ -267,8 +266,6 @@ def control_loop(
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
# if record_delta_actions:
# action["action"] = action["action"] - current_joint_positions
else:
observation = robot.capture_observation()

View File

@ -363,8 +363,6 @@ def replay(
start_episode_t = time.perf_counter()
action = actions[idx]["action"]
# if replay_delta_actions:
# action = action + current_joint_positions
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t

File diff suppressed because it is too large

View File

@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env):
self,
robot,
use_delta_action_space: bool = True,
delta: float | None = None,
display_cameras: bool = False,
):
"""
@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env):
robot: The robot interface object used to connect and interact with the physical robot.
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
joint positions are used.
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
0 and 1 when using a delta action space.
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
"""
super().__init__()
@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env):
self.current_step = 0
self.episode_data = None
self.delta = delta
self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
self.device = device
def step(self, action):
observation, _, terminated, truncated, info = self.env.step(action)
observation, reward, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
]
start_time = time.perf_counter()
with torch.inference_mode():
reward = (
success = (
self.reward_classifier.predict_reward(images, threshold=0.8)
if self.reward_classifier is not None
else 0.0
)
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
if reward == 1.0:
if success == 1.0:
terminated = True
reward = 1.0
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
@ -720,11 +718,13 @@ class ResetWrapper(gym.Wrapper):
env: HILSerlRobotEnv,
reset_pose: np.ndarray | None = None,
reset_time_s: float = 5,
open_gripper_on_reset: bool = False,
):
super().__init__(env)
self.reset_time_s = reset_time_s
self.reset_pose = reset_pose
self.robot = self.unwrapped.robot
self.open_gripper_on_reset = open_gripper_on_reset
def reset(self, *, seed=None, options=None):
if self.reset_pose is not None:
@ -733,6 +733,14 @@ class ResetWrapper(gym.Wrapper):
reset_follower_position(self.robot, self.reset_pose)
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
log_say("Reset the environment done.", play_sounds=True)
if self.open_gripper_on_reset:
current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
current_joint_pos[-1] = MAX_GRIPPER_COMMAND
self.robot.send_action(torch.from_numpy(current_joint_pos))
busy_wait(0.1)
current_joint_pos[-1] = 0.0
self.robot.send_action(torch.from_numpy(current_joint_pos))
busy_wait(0.2)
else:
log_say(
f"Manually reset the environment for {self.reset_time_s} seconds.",
@ -761,6 +769,75 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
return observation
class GripperPenaltyWrapper(gym.RewardWrapper):
def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
super().__init__(env)
self.penalty = penalty
self.gripper_penalty_in_reward = gripper_penalty_in_reward
self.last_gripper_state = None
def reward(self, reward, action):
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND
gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and action_normalized < -0.5
)
return reward + self.penalty * int(gripper_penalty_bool)
def step(self, action):
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
if isinstance(action, tuple):
gripper_action = action[0][-1]
else:
gripper_action = action[-1]
obs, reward, terminated, truncated, info = self.env.step(action)
gripper_penalty = self.reward(reward, gripper_action)
if self.gripper_penalty_in_reward:
reward += gripper_penalty
else:
info["gripper_penalty"] = gripper_penalty
return obs, reward, terminated, truncated, info
def reset(self, **kwargs):
self.last_gripper_state = None
obs, info = super().reset(**kwargs)
if self.gripper_penalty_in_reward:
info["gripper_penalty"] = 0.0
return obs, info
class GripperActionWrapper(gym.ActionWrapper):
def __init__(self, env, quantization_threshold: float = 0.2):
super().__init__(env)
self.quantization_threshold = quantization_threshold
def action(self, action):
is_intervention = False
if isinstance(action, tuple):
action, is_intervention = action
gripper_command = action[-1]
# Gripper actions are in [0, 2];
# we quantize them to -1, 0, or 1
gripper_command = gripper_command - 1.0
if self.quantization_threshold is not None:
# Quantize gripper command to -1, 0 or 1
gripper_command = (
np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
)
gripper_command = gripper_command * MAX_GRIPPER_COMMAND
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
action[-1] = gripper_action.item()
return action, is_intervention
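A small numeric walk-through of the quantization above (MAX_GRIPPER_COMMAND is whatever constant the module defines; the placeholder value below is an assumption for illustration only):

import numpy as np

MAX_GRIPPER_COMMAND = 50.0        # placeholder value, not the real constant
quantization_threshold = 0.2

for raw in (0.1, 0.9, 1.05, 1.9):             # raw gripper actions live in [0, 2]
    command = raw - 1.0                        # shift to [-1, 1]
    if abs(command) > quantization_threshold:
        command = float(np.sign(command))      # quantize to -1 or +1
    else:
        command = 0.0                          # small commands mean "do not move the gripper"
    delta = command * MAX_GRIPPER_COMMAND      # increment added to the current gripper position
    print(f"raw={raw:4.2f} -> delta={delta:+.1f}")

The wrapper then clips current_position + delta into [0, MAX_GRIPPER_COMMAND] before writing it back into the action.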
class EEActionWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None, use_gripper=False):
super().__init__(env)
@ -780,10 +857,12 @@ class EEActionWrapper(gym.ActionWrapper):
]
)
if self.use_gripper:
action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
# gripper actions open at 2.0, and closed at 0.0
min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
ee_action_space = gym.spaces.Box(
low=-action_space_bounds,
high=action_space_bounds,
low=min_action_space_bounds,
high=max_action_space_bounds,
shape=(3 + int(self.use_gripper),),
dtype=np.float32,
)
@ -820,17 +899,7 @@ class EEActionWrapper(gym.ActionWrapper):
fk_func=self.fk_function,
)
if self.use_gripper:
# Quantize gripper command to -1, 0 or 1
if gripper_command < -0.2:
gripper_command = -1.0
elif gripper_command > 0.2:
gripper_command = 1.0
else:
gripper_command = 0.0
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
target_joint_pos[-1] = gripper_action
target_joint_pos[-1] = gripper_command
return target_joint_pos, is_intervention
@ -951,11 +1020,11 @@ class GamepadControlWrapper(gym.Wrapper):
if self.use_gripper:
gripper_command = self.controller.gripper_command()
if gripper_command == "open":
gamepad_action = np.concatenate([gamepad_action, [1.0]])
gamepad_action = np.concatenate([gamepad_action, [2.0]])
elif gripper_command == "close":
gamepad_action = np.concatenate([gamepad_action, [-1.0]])
else:
gamepad_action = np.concatenate([gamepad_action, [0.0]])
else:
gamepad_action = np.concatenate([gamepad_action, [1.0]])
# Check episode ending buttons
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
@ -1095,7 +1164,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env = HILSerlRobotEnv(
robot=robot,
display_cameras=cfg.wrapper.display_cameras,
delta=cfg.wrapper.delta_action,
use_delta_action_space=cfg.wrapper.use_relative_joint_positions
and cfg.wrapper.ee_action_space_params is None,
)
@ -1118,12 +1186,22 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
# Add reward computation and control wrappers
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper:
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
if cfg.wrapper.gripper_penalty is not None:
env = GripperPenaltyWrapper(
env=env,
penalty=cfg.wrapper.gripper_penalty,
gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
)
if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(
env=env,
ee_action_space_params=cfg.wrapper.ee_action_space_params,
use_gripper=cfg.wrapper.use_gripper,
)
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
env = GamepadControlWrapper(
@ -1140,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env=env,
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.wrapper.reset_time_s,
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
)
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
@ -1289,11 +1368,10 @@ def record_dataset(env, policy, cfg):
dataset.push_to_hub()
def replay_episode(env, repo_id, root=None, episode=0):
def replay_episode(env, cfg):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
local_files_only = root is not None
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
env.reset()
actions = dataset.hf_dataset.select_columns("action")
@ -1301,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action = actions[idx]["action"][:4]
action = actions[idx]["action"]
env.step((action, False))
# env.step((action / env.unwrapped.delta, False))
@ -1332,9 +1410,7 @@ def main(cfg: EnvConfig):
if cfg.mode == "replay":
replay_episode(
env,
cfg.replay_repo_id,
root=cfg.dataset_root,
episode=cfg.replay_episode,
cfg=cfg,
)
exit()

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python
# !/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
@ -269,6 +269,7 @@ def add_actor_information_and_train(
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
async_prefetch = cfg.policy.async_prefetch
# Initialize logging for multiprocessing
if not use_threads(cfg):
@ -326,6 +327,9 @@ def add_actor_information_and_train(
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested
@ -359,13 +363,26 @@ def add_actor_information_and_train(
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
logging.debug("[LEARNER] Initializing online replay buffer iterator")
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
batch = replay_buffer.sample(batch_size=batch_size)
# Sample from the iterators
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
@ -390,26 +407,40 @@ def add_actor_information_and_train(
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss
loss_critic = policy.forward(forward_batch, model="critic")
# Use the forward method for the main critic loss (the grasp critic is optimized separately below)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
# Grasp critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
# Update target networks
policy.update_target_networks()
batch = replay_buffer.sample(batch_size=batch_size)
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
@ -437,63 +468,80 @@ def add_actor_information_and_train(
"next_observation_feature": next_observation_features,
}
# Use the forward method for critic loss
loss_critic = policy.forward(forward_batch, model="critic")
# Use the forward method for the main critic loss (the grasp critic is optimized separately below)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
# Initialize training info dictionary
training_infos = {
"loss_critic": loss_critic.item(),
"critic_grad_norm": critic_grad_norm,
}
# Grasp critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
# Add grasp critic info to training info
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Use the forward method for actor loss
loss_actor = policy.forward(forward_batch, model="actor")
# Actor optimization
actor_output = policy.forward(forward_batch, model="actor")
loss_actor = actor_output["loss_actor"]
optimizers["actor"].zero_grad()
loss_actor.backward()
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization using forward method
loss_temperature = policy.forward(forward_batch, model="temperature")
# Temperature optimization
temperature_output = policy.forward(forward_batch, model="temperature")
loss_temperature = temperature_output["loss_temperature"]
optimizers["temperature"].zero_grad()
loss_temperature.backward()
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# Update temperature
policy.update_temperature()
# Check if it's time to push updated policy to actors
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
# Update target networks
policy.update_target_networks()
# Log training metrics at specified intervals
@ -697,7 +745,7 @@ def save_training_checkpoint(
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, temperature, and (when num_discrete_actions is set) grasp critic components of a reinforcement learning policy.
@ -725,10 +773,19 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
optimizer_actor = torch.optim.Adam(
# NOTE: When the encoder is shared, its weights are excluded here so they are not updated by the actor's gradients
params=policy.actor.parameters_to_optimize,
params=[
p
for n, p in policy.actor.named_parameters()
if not n.startswith("encoder") or not policy.config.shared_encoder
],
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
@ -736,6 +793,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if cfg.policy.num_discrete_actions is not None:
optimizers["grasp_critic"] = optimizer_grasp_critic
return optimizers, lr_scheduler
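The actor optimizer above deliberately skips parameters whose names start with "encoder" when the encoder is shared, so only the critic optimizer updates the shared encoder. A self-contained sketch of that named_parameters filter with a toy module:

import torch
import torch.nn as nn

class ToyActor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)  # shared with the critics in the real policy
        self.head = nn.Linear(4, 2)

actor = ToyActor()
shared_encoder = True
actor_params = [
    p
    for n, p in actor.named_parameters()
    if not n.startswith("encoder") or not shared_encoder
]
optimizer_actor = torch.optim.Adam(actor_params, lr=3e-4)
# With shared_encoder=True, only head.weight and head.bias are updated by this optimizer.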
@ -936,7 +995,6 @@ def initialize_offline_replay_buffer(
device=device,
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity,
@ -970,12 +1028,8 @@ def get_observation_features(
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
)
next_observation_features = (
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
)
observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
return observation_features, next_observation_features
@ -1037,6 +1091,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
parameters_queue.put(state_bytes)
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
"""
Checks whether each parameter in the module has a gradient.
Args:
module (nn.Module): A PyTorch module whose parameters will be inspected.
Returns:
dict[str, bool]: A dictionary where each key is the parameter name and the value is
True if the parameter has an associated gradient (i.e. .grad is not None),
otherwise False.
"""
grad_status = {}
for name, param in module.named_parameters():
grad_status[name] = param.grad is not None
return grad_status
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
"""
Returns a dictionary of parameters (from the given model) that also exist in the grad_status dictionary.
Args:
model (nn.Module): The model whose parameters are checked.
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
whether each parameter has a gradient.
Returns:
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
"""
# Get the model's parameter names as a set.
model_param_names = {name for name, _ in model.named_parameters()}
# Intersect parameter names between the model and grad_status.
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
return overlapping
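A hedged usage sketch of check_weight_gradients as a debugging aid inside the learner loop; it assumes policy, forward_batch, and logging from the surrounding code, and the logged message is illustrative. get_overlapping_parameters can then restrict such a report to parameter names that also appear in another module:

# Sketch only: call after a backward pass to see which parameters actually received gradients.
grad_status = check_weight_gradients(policy.actor)
params_without_grad = [name for name, has_grad in grad_status.items() if not has_grad]
logging.debug(f"[LEARNER] actor parameters without gradients: {params_without_grad}")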
def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):

View File

@ -1,5 +1,3 @@
import logging
import time
from typing import Any
import einops
@ -10,7 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.configs import parser
def preprocess_maniskill_observation(
@ -153,6 +150,27 @@ class TimeLimitWrapper(gym.Wrapper):
return super().reset(seed=seed, options=options)
class ManiskillMockGripperWrapper(gym.Wrapper):
def __init__(self, env, nb_discrete_actions: int = 3):
super().__init__(env)
new_shape = env.action_space[0].shape[0] + 1
new_low = np.concatenate([env.action_space[0].low, [0]])
new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
def step(self, action):
if isinstance(action, tuple):
action_agent, telop_action = action
else:
telop_action = 0
action_agent = action
real_action = action_agent[:-1]
final_action = (real_action, telop_action)
obs, reward, terminated, truncated, info = self.env.step(final_action)
return obs, reward, terminated, truncated, info
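The mock-gripper wrapper pads the agent's Box action space with one extra discrete-valued entry and strips it again before stepping the wrapped env, so a policy trained with num_discrete_actions can drive a ManiSkill task that has no native gripper channel. A hedged usage sketch, assuming env is the tuple-action env built by the wrappers in make_maniskill below:

# Sketch only: env comes from the wrappers applied in make_maniskill.
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
padded_action, teleop_flag = env.action_space.sample()  # last entry of padded_action is the fake gripper command
obs, reward, terminated, truncated, info = env.step((padded_action, teleop_flag))
# Internally, the wrapper drops the last entry of padded_action before calling the wrapped env.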
def make_maniskill(
cfg: ManiskillEnvConfig,
n_envs: int | None = None,
@ -197,40 +215,42 @@ def make_maniskill(
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
if cfg.mock_gripper:
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
return env
@parser.wrap()
def main(cfg: ManiskillEnvConfig):
"""Main function to run the ManiSkill environment."""
# Create the ManiSkill environment
env = make_maniskill(cfg, n_envs=1)
# @parser.wrap()
# def main(cfg: TrainPipelineConfig):
# """Main function to run the ManiSkill environment."""
# # Create the ManiSkill environment
# env = make_maniskill(cfg.env, n_envs=1)
# Reset the environment
obs, info = env.reset()
# # Reset the environment
# obs, info = env.reset()
# Run a simple interaction loop
sum_reward = 0
for i in range(100):
# Sample a random action
action = env.action_space.sample()
# # Run a simple interaction loop
# sum_reward = 0
# for i in range(100):
# # Sample a random action
# action = env.action_space.sample()
# Step the environment
start_time = time.perf_counter()
obs, reward, terminated, truncated, info = env.step(action)
step_time = time.perf_counter() - start_time
sum_reward += reward
# Log information
# # Step the environment
# start_time = time.perf_counter()
# obs, reward, terminated, truncated, info = env.step(action)
# step_time = time.perf_counter() - start_time
# sum_reward += reward
# # Log information
# Reset if episode terminated
if terminated or truncated:
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
sum_reward = 0
obs, info = env.reset()
# # Reset if episode terminated
# if terminated or truncated:
# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
# sum_reward = 0
# obs, info = env.reset()
# Close the environment
env.close()
# # Close the environment
# env.close()
# if __name__ == "__main__":