Commit a72146e814 by Adil Zouitine, 2025-04-08 10:50:07 +02:00 (committed by GitHub)
7 changed files with 871 additions and 684 deletions


@ -203,6 +203,9 @@ class EnvWrapperConfig:
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
gripper_quantization_threshold: float = 0.8
gripper_penalty: float = 0.0
open_gripper_on_reset: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@ -254,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig):
robot: str = "so100" # This is a hack to make the robot config work
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
mock_gripper: bool = False
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),


@ -85,12 +85,14 @@ class SACConfig(PreTrainedConfig):
freeze_vision_encoder: Whether to freeze the vision encoder during training.
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
num_discrete_actions: Number of discrete actions, e.g. for gripper actions.
concurrency: Configuration for concurrency settings.
actor_learner: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
offline_buffer_capacity: Capacity of the offline replay buffer.
async_prefetch: Whether to use asynchronous prefetching for the buffers.
online_step_before_learning: Number of steps before learning starts.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the SAC algorithm.
@ -144,12 +146,14 @@ class SACConfig(PreTrainedConfig):
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
num_discrete_actions: int | None = None
# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
async_prefetch: bool = False
online_step_before_learning: int = 100
policy_update_freq: int = 1
@ -173,7 +177,7 @@ class SACConfig(PreTrainedConfig):
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
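The new fields are opt-in: leaving num_discrete_actions at None keeps the policy purely continuous. A minimal sketch of enabling the grasp critic and asynchronous prefetching; the values are illustrative and the import assumes the classes live in this configuration module:

from lerobot.common.policies.sac.configuration_sac import SACConfig, CriticNetworkConfig

config = SACConfig(
    num_discrete_actions=3,   # e.g. open / stay / close gripper
    async_prefetch=True,      # sample replay batches in a background thread
    grasp_critic_network_kwargs=CriticNetworkConfig(),  # reuse the default critic MLP settings
)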


@ -33,6 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
PreTrainedPolicy,
@ -49,6 +51,8 @@ class SACPolicy(
config.validate_features()
self.config = config
continuous_action_dim = config.output_features["action"].shape[0]
if config.dataset_stats is not None:
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
self.normalize_inputs = Normalize(
@ -77,11 +81,12 @@ class SACPolicy(
else:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
self.shared_encoder = config.shared_encoder
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
@ -96,7 +101,7 @@ class SACPolicy(
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
input_dim=encoder_critic.output_dim + continuous_action_dim,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
@ -112,15 +117,41 @@ class SACPolicy(
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.grasp_critic = None
self.grasp_critic_target = None
if config.num_discrete_actions is not None:
# Create grasp critic
self.grasp_critic = GraspCritic(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
**asdict(config.grasp_critic_network_kwargs),
)
# Create target grasp critic
self.grasp_critic_target = GraspCritic(
encoder=encoder_critic,
input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions,
**asdict(config.grasp_critic_network_kwargs),
)
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
self.grasp_critic = torch.compile(self.grasp_critic)
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
action_dim=config.output_features["action"].shape[0],
action_dim=continuous_action_dim,
encoder_is_shared=config.shared_encoder,
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@ -131,11 +162,14 @@ class SACPolicy(
self.temperature = self.log_alpha.exp().item()
def get_optim_params(self) -> dict:
return {
optim_params = {
"actor": self.actor.parameters_to_optimize,
"critic": self.critic_ensemble.parameters_to_optimize,
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
return optim_params
def reset(self):
"""Reset the policy"""
@ -151,8 +185,19 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
actions, _, _ = self.actor(batch)
# We cached the encoder output to avoid recomputing it
observations_features = None
if self.shared_encoder:
observations_features = self.actor.encoder.get_image_features(batch)
actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.num_discrete_actions is not None:
discrete_action_value = self.grasp_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
actions = torch.cat([actions, discrete_action], dim=-1)
return actions
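For reference, a self-contained sketch of what the concatenation above produces at inference time: the grasp critic's argmax becomes the last action dimension. Plain torch with dummy numbers, not code from the diff:

import torch

continuous_action = torch.tensor([[0.1, -0.2, 0.0, 0.3, 0.05, -0.1]])   # unnormalized continuous part
grasp_q_values = torch.tensor([[0.2, 1.5, -0.3]])                       # one Q-value per discrete gripper action
gripper_action = torch.argmax(grasp_q_values, dim=-1, keepdim=True)     # tensor([[1]])
full_action = torch.cat([continuous_action, gripper_action.float()], dim=-1)  # shape (1, 7), gripper last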
def critic_forward(
@ -172,14 +217,30 @@ class SACPolicy(
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
"""Forward pass through a grasp critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the grasp critic network
"""
grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
q_values = grasp_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature"] = "critic",
model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
@ -192,12 +253,11 @@ class SACPolicy(
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", or "temperature")
model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
Returns:
The computed loss tensor
"""
# TODO: (maractingi, azouitine) Respect the function signature we output tensors
# Extract common components from batch
actions: Tensor = batch["action"]
observations: dict[str, Tensor] = batch["state"]
@ -210,7 +270,7 @@ class SACPolicy(
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
return self.compute_loss_critic(
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
@ -220,17 +280,41 @@ class SACPolicy(
next_observation_features=next_observation_features,
)
return {"loss_critic": loss_critic}
if model == "grasp_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_grasp_critic = self.compute_loss_grasp_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_grasp_critic": loss_grasp_critic}
if model == "actor": if model == "actor":
return self.compute_loss_actor( return {
"loss_actor": self.compute_loss_actor(
observations=observations, observations=observations,
observation_features=observation_features, observation_features=observation_features,
) )
}
if model == "temperature": if model == "temperature":
return self.compute_loss_temperature( return {
"loss_temperature": self.compute_loss_temperature(
observations=observations, observations=observations,
observation_features=observation_features, observation_features=observation_features,
) )
}
raise ValueError(f"Unknown model type: {model}") raise ValueError(f"Unknown model type: {model}")
@ -245,6 +329,16 @@ class SACPolicy(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.grasp_critic_target.parameters(),
self.grasp_critic.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
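The grasp critic target therefore follows the same Polyak rule as the continuous critics. Written out for a single parameter, with tau standing for critic_target_update_weight and dummy numbers:

import torch

tau = 0.005                                  # hypothetical critic_target_update_weight
param = torch.tensor([1.0, 2.0])             # online grasp critic parameter
target_param = torch.zeros(2)                # target grasp critic parameter
target_param = tau * param + (1 - tau) * target_param  # -> tensor([0.0050, 0.0100])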
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
@ -287,6 +381,11 @@ class SACPolicy(
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
@ -307,6 +406,64 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_grasp_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
if complementary_info is not None:
gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_grasp_qs = self.grasp_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_grasp_qs = self.grasp_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_grasp_q = torch.gather(
target_next_grasp_qs, dim=1, index=best_next_grasp_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_gripper = rewards
if gripper_penalties is not None:
rewards_gripper = rewards + gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
# Get predicted Q-values for current observations
predicted_grasp_qs = self.grasp_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
return grasp_critic_loss
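The target above is the standard double-DQN recipe: the online grasp critic picks the best next action and the target network evaluates it. A self-contained numeric sketch in plain torch (dummy values, not code from the diff):

import torch

q_next_online = torch.tensor([[0.1, 0.9, 0.3]])   # online grasp critic at next observation
q_next_target = torch.tensor([[0.2, 0.7, 0.4]])   # target grasp critic at next observation
reward, done, discount = torch.tensor([1.0]), torch.tensor([0.0]), 0.99

best_next = torch.argmax(q_next_online, dim=-1, keepdim=True)                # tensor([[1]])
evaluated = torch.gather(q_next_target, dim=1, index=best_next).squeeze(-1)  # tensor([0.7000])
td_target = reward + (1 - done) * discount * evaluated                       # tensor([1.6930])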
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
@ -337,6 +494,109 @@ class SACPolicy(
return actor_loss
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict)
feat.append(vision_encoder_cache)
if vision_encoder_cache is not None:
feat.append(vision_encoder_cache)
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
# [N*B, C, H, W]
if len(self.all_image_keys) > 0:
# Batch all images along the batch dimension, then encode them.
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
return embeddings_image
return None
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
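Because the encoder now accepts a vision_encoder_cache, image features can be computed once per batch and reused by every head that shares the encoder. A rough sketch of the pattern, assuming `policy` is a SACPolicy with shared_encoder=True and `batch` contains image observations:

features = policy.actor.encoder.get_image_features(batch)   # encode all camera images once
actions, _, _ = policy.actor(batch, features)                # actor reuses the cached features
q_values = policy.critic_forward(batch, actions, observation_features=features)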
class MLP(nn.Module):
def __init__(
self,
@ -459,7 +719,7 @@ class CriticEnsemble(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
encoder: SACObservationEncoder,
ensemble: List[CriticHead],
output_normalization: nn.Module,
init_final: Optional[float] = None,
@ -491,11 +751,7 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
obs_enc = self.encoder(observations, observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
@ -509,10 +765,57 @@ class CriticEnsemble(nn.Module):
return q_values
class GraspCritic(nn.Module):
def __init__(
self,
encoder: nn.Module,
input_dim: int,
hidden_dims: list[int],
output_dim: int = 3,
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.encoder = encoder
self.output_dim = output_dim
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
self.parameters_to_optimize = []
self.parameters_to_optimize += list(self.net.parameters())
self.parameters_to_optimize += list(self.output_layer.parameters())
def forward(
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device by cloning first to avoid inplace operations
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
return self.output_layer(self.net(obs_enc))
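As a toy illustration of the structure (plain torch, not the class above): the grasp critic is an MLP trunk over the encoder's latent followed by a linear head with one Q-value per discrete gripper action.

import torch
from torch import nn

latent_dim, num_discrete_actions = 256, 3
toy_grasp_head = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.SiLU(),
    nn.Linear(256, 256), nn.SiLU(),
    nn.Linear(256, num_discrete_actions),
)
obs_latent = torch.randn(8, latent_dim)      # stand-in for SACObservationEncoder output
q_values = toy_grasp_head(obs_latent)        # shape (8, 3)
greedy_gripper = q_values.argmax(dim=-1)     # discrete action per sample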
class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
encoder: SACObservationEncoder,
network: nn.Module,
action_dim: int,
log_std_min: float = -5,
@ -523,7 +826,7 @@ class Policy(nn.Module):
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder = encoder
self.encoder: SACObservationEncoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
@ -566,11 +869,7 @@ class Policy(nn.Module):
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
# Get network outputs
outputs = self.network(obs_enc)
@ -614,96 +913,6 @@ class Policy(nn.Module):
return observations
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()


@ -15,7 +15,6 @@
# limitations under the License.
import functools
import io
import os
import pickle
from typing import Any, Callable, Optional, Sequence, TypedDict
@ -33,7 +32,7 @@ class Transition(TypedDict):
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, Any] = None
complementary_info: dict[str, torch.Tensor | float | int] | None = None
class BatchTransition(TypedDict):
@ -43,41 +42,47 @@ class BatchTransition(TypedDict):
next_state: dict[str, torch.Tensor]
done: torch.Tensor
truncated: torch.Tensor
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
# Move state tensors to CPU
device = torch.device(device)
non_blocking = device.type == "cuda"
# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}
# Move action to CPU
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
# No need to move reward or done, as they are float and bool
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
# Move next_state tensors to CPU
# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=device.type == "cuda")
for key, val in transition["next_state"].items()
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}
# If complementary_info is present, move its tensors to CPU
# if transition["complementary_info"] is not None:
#     transition["complementary_info"] = {
#         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
#     }
# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(
val, device=device, non_blocking=non_blocking
)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition
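A usage sketch, assuming the Transition TypedDict and this helper are in scope; the tensors are dummies:

import torch

dummy = Transition(
    state={"observation.state": torch.zeros(1, 10)},
    action=torch.zeros(1, 7),
    reward=0.0,
    done=False,
    truncated=False,
    next_state={"observation.state": torch.zeros(1, 10)},
    complementary_info={"gripper_penalty": torch.tensor([-0.01])},
)
dummy = move_transition_to_device(dummy, device="cuda" if torch.cuda.is_available() else "cpu")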
@ -217,7 +222,12 @@ class ReplayBuffer:
self.image_augmentation_function = torch.compile(base_function)
self.use_drq = use_drq
def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
def _initialize_storage(
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
complementary_info: Optional[dict[str, torch.Tensor]] = None,
):
"""Initialize the storage tensors based on the first transition.""" """Initialize the storage tensors based on the first transition."""
# Determine shapes from the first transition # Determine shapes from the first transition
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
@ -244,6 +254,27 @@ class ReplayBuffer:
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
# Initialize storage for complementary_info
self.has_complementary_info = complementary_info is not None
self.complementary_info_keys = []
self.complementary_info = {}
if self.has_complementary_info:
self.complementary_info_keys = list(complementary_info.keys())
# Pre-allocate tensors for each key in complementary_info
for key, value in complementary_info.items():
if isinstance(value, torch.Tensor):
value_shape = value.squeeze(0).shape
self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device
)
elif isinstance(value, (int, float)):
# Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
self.initialized = True
def __len__(self):
@ -262,7 +293,7 @@ class ReplayBuffer:
"""Saves a transition, ensuring tensors are stored on the designated storage device.""" """Saves a transition, ensuring tensors are stored on the designated storage device."""
# Initialize storage if this is the first transition # Initialize storage if this is the first transition
if not self.initialized: if not self.initialized:
self._initialize_storage(state=state, action=action) self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Store the transition in pre-allocated tensors # Store the transition in pre-allocated tensors
for key in self.states: for key in self.states:
@ -277,6 +308,17 @@ class ReplayBuffer:
self.dones[self.position] = done
self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info
for key in self.complementary_info_keys:
if key in complementary_info:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int, float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
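A sketch of feeding the buffer a transition whose complementary_info carries a per-step gripper penalty; the first add() call also sizes the pre-allocated complementary_info storage. Constructor arguments follow the existing ReplayBuffer signature and the values are dummies:

import torch

buffer = ReplayBuffer(capacity=1000, device="cpu", state_keys=["observation.state"], storage_device="cpu")
state = {"observation.state": torch.zeros(1, 10)}
buffer.add(
    state=state,
    action=torch.zeros(1, 7),
    reward=0.0,
    next_state=state,
    done=False,
    truncated=False,
    complementary_info={"gripper_penalty": torch.tensor([-0.01])},
)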
@ -335,6 +377,13 @@ class ReplayBuffer:
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition(
state=batch_state,
action=batch_actions,
@ -342,8 +391,112 @@ class ReplayBuffer:
next_state=batch_next_state,
done=batch_dones,
truncated=batch_truncateds,
complementary_info=batch_complementary_info,
)
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""
Creates an infinite iterator that yields batches of transitions.
Will automatically restart when internal iterator is exhausted.
Args:
batch_size (int): Size of batches to sample
async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
queue_size (int): Number of batches to prefetch (default: 2)
Yields:
BatchTransition: Batched transitions
"""
while True: # Create an infinite loop
if async_prefetch:
# Get the standard iterator
iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
else:
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
# Yield all items from the iterator
try:
yield from iterator
except StopIteration:
# Just continue the outer loop to create a new iterator
pass
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
"""
Creates an iterator that prefetches batches in a background thread.
Args:
queue_size (int): Number of batches to prefetch (default: 2)
batch_size (int): Size of batches to sample (default: 128)
Yields:
BatchTransition: Prefetched batch transitions
"""
import queue
import threading
# Use thread-safe queue
data_queue = queue.Queue(maxsize=queue_size)
running = [True] # Use list to allow modification in nested function
def prefetch_worker():
while running[0]:
try:
# Sample data and add to queue
data = self.sample(batch_size)
data_queue.put(data, block=True, timeout=0.5)
except queue.Full:
continue
except Exception as e:
print(f"Prefetch error: {e}")
break
# Start prefetching thread
thread = threading.Thread(target=prefetch_worker, daemon=True)
thread.start()
try:
while running[0]:
try:
yield data_queue.get(block=True, timeout=0.5)
except queue.Empty:
if not thread.is_alive():
break
finally:
# Clean up
running[0] = False
thread.join(timeout=1.0)
def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
"""
Creates a simple non-threaded iterator that yields batches.
Args:
batch_size (int): Size of batches to sample
queue_size (int): Number of initial batches to prefetch
Yields:
BatchTransition: Batch transitions
"""
import collections
queue = collections.deque()
def enqueue(n):
for _ in range(n):
data = self.sample(batch_size)
queue.append(data)
enqueue(queue_size)
while queue:
yield queue.popleft()
enqueue(1)
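A sketch of how a learner loop might consume the iterator, assuming `buffer` is an already-populated ReplayBuffer:

iterator = buffer.get_iterator(batch_size=256, async_prefetch=True, queue_size=2)
for step in range(1000):
    batch = next(iterator)   # a BatchTransition already on buffer.device
    # ... compute critic / grasp-critic / actor losses from `batch` here ...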
@classmethod
def from_lerobot_dataset(
cls,
@ -415,7 +568,19 @@ class ReplayBuffer:
if action_delta is not None:
first_action = first_action / action_delta
replay_buffer._initialize_storage(state=first_state, action=first_action)
# Get complementary info if available
first_complementary_info = None
if (
"complementary_info" in first_transition
and first_transition["complementary_info"] is not None
):
first_complementary_info = {
k: v.to(device) for k, v in first_transition["complementary_info"].items()
}
replay_buffer._initialize_storage(
state=first_state, action=first_action, complementary_info=first_complementary_info
)
# Fill the buffer with all transitions
for data in list_transition:
@ -443,6 +608,7 @@ class ReplayBuffer:
next_state=data["next_state"],
done=data["done"],
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
complementary_info=data.get("complementary_info", None),
)
return replay_buffer
@ -484,6 +650,15 @@ class ReplayBuffer:
f_info = guess_feature_info(t=sample_val, name=key)
features[key] = f_info
# Add complementary_info keys if available
if self.has_complementary_info:
for key in self.complementary_info_keys:
sample_val = self.complementary_info[key][0]
if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
sample_val = sample_val.unsqueeze(0)
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
features[f"complementary_info.{key}"] = f_info
# Create an empty LeRobotDataset
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
@ -517,6 +692,19 @@ class ReplayBuffer:
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
# Add complementary_info if available
if self.has_complementary_info:
for key in self.complementary_info_keys:
val = self.complementary_info[key][actual_idx]
# Convert tensors to CPU
if isinstance(val, torch.Tensor):
if val.ndim == 0:
val = val.unsqueeze(0)
frame_dict[f"complementary_info.{key}"] = val.cpu()
# Non-tensor values can be used directly
else:
frame_dict[f"complementary_info.{key}"] = val
# Add task field which is required by LeRobotDataset
frame_dict["task"] = task_name
@ -583,6 +771,10 @@ class ReplayBuffer:
sample = dataset[0]
has_done_key = "next.done" in sample
# Check for complementary_info keys
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
has_complementary_info = len(complementary_info_keys) > 0
# If not, we need to infer it from episode boundaries
if not has_done_key:
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
@ -632,6 +824,22 @@ class ReplayBuffer:
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
next_state = next_state_data
# ----- 5) Complementary info (if available) -----
complementary_info = None
if has_complementary_info:
complementary_info = {}
for key in complementary_info_keys:
# Strip the "complementary_info." prefix to get the actual key
clean_key = key[len("complementary_info.") :]
val = current_sample[key]
# Handle tensor and non-tensor values differently
if isinstance(val, torch.Tensor):
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
else:
# TODO: (azouitine) Check if it's necessary to convert to tensor
# For non-tensor values, use directly
complementary_info[clean_key] = val
# ----- Construct the Transition -----
transition = Transition(
state=current_state,
@ -640,6 +848,7 @@ class ReplayBuffer:
next_state=next_state,
done=done,
truncated=truncated,
complementary_info=complementary_info,
)
transitions.append(transition)
@ -647,12 +856,13 @@ class ReplayBuffer:
# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t: torch.Tensor, name: str):
def guess_feature_info(t, name: str):
"""
Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
Otherwise default to 'float32' for numeric. You can customize as needed.
Otherwise default to appropriate dtype for numeric.
"""
shape = tuple(t.shape)
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
if len(shape) == 3 and shape[0] in [1, 3]:
@ -672,32 +882,33 @@ def concatenate_batch_transitions(
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
# Concatenate state fields
left_batch_transitions["state"] = { left_batch_transitions["state"] = {
key: torch.cat( key: torch.cat(
[ [left_batch_transitions["state"][key], right_batch_transition["state"][key]],
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0, dim=0,
) )
for key in left_batch_transitions["state"] for key in left_batch_transitions["state"]
} }
# Concatenate basic fields
left_batch_transitions["action"] = torch.cat( left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0 [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
) )
left_batch_transitions["reward"] = torch.cat( left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
) )
# Concatenate next_state fields
left_batch_transitions["next_state"] = { left_batch_transitions["next_state"] = {
key: torch.cat( key: torch.cat(
[ [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0, dim=0,
) )
for key in left_batch_transitions["next_state"] for key in left_batch_transitions["next_state"]
} }
# Concatenate done and truncated fields
left_batch_transitions["done"] = torch.cat( left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0 [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
) )
@ -705,479 +916,114 @@ def concatenate_batch_transitions(
[left_batch_transitions["truncated"], right_batch_transition["truncated"]], [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
dim=0, dim=0,
) )
# Handle complementary_info
left_info = left_batch_transitions.get("complementary_info")
right_info = right_batch_transition.get("complementary_info")
# Only process if right_info exists
if right_info is not None:
# Initialize left complementary_info if needed
if left_info is None:
left_batch_transitions["complementary_info"] = right_info
else:
# Concatenate each field
for key in right_info:
if key in left_info:
left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
else:
left_info[key] = right_info[key]
return left_batch_transitions
if __name__ == "__main__": if __name__ == "__main__":
from tempfile import TemporaryDirectory
# ===== Test 1: Create and use a synthetic ReplayBuffer ===== def test_load_dataset_with_complementary_info():
print("Testing synthetic ReplayBuffer...") """
Test loading a dataset with complementary_info into a ReplayBuffer.
The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains
gripper_penalty values in complementary_info.
"""
import time
# Create sample data dimensions from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
batch_size = 32
state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
action_dim = (6,)
# Create a buffer print("Loading dataset with complementary info...")
buffer = ReplayBuffer( # Load a small subset of the dataset (first episode)
capacity=1000, dataset = LeRobotDataset(
device="cpu", repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
state_keys=list(state_dims.keys()),
use_drq=True,
storage_device="cpu",
) )
# Add some random transitions print(f"Dataset loaded with {len(dataset)} frames")
for i in range(100):
# Create dummy transition data
state = {
"observation.image": torch.rand(1, 3, 84, 84),
"observation.state": torch.rand(1, 10),
}
action = torch.rand(1, 6)
reward = 0.5
next_state = {
"observation.image": torch.rand(1, 3, 84, 84),
"observation.state": torch.rand(1, 10),
}
done = False if i < 99 else True
truncated = False
buffer.add(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
truncated=truncated,
)
# Test sampling
batch = buffer.sample(batch_size)
print(f"Buffer size: {len(buffer)}")
print(
f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
)
print(f"Sampled batch action shape: {batch['action'].shape}")
print(f"Sampled batch reward shape: {batch['reward'].shape}")
print(f"Sampled batch done shape: {batch['done'].shape}")
print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
# ===== Test for state-action-reward alignment =====
print("\nTesting state-action-reward alignment...")
# Create a buffer with controlled transitions where we know the relationships
aligned_buffer = ReplayBuffer(
capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
)
# Create transitions with known relationships
# - Each state has a unique signature value
# - Action is 2x the state signature
# - Reward is 3x the state signature
# - Next state is signature + 0.01 (unless at episode end)
for i in range(100):
# Create a state with a signature value that encodes the transition number
signature = float(i) / 100.0
state = {"state_value": torch.tensor([[signature]]).float()}
# Action is 2x the signature
action = torch.tensor([[2.0 * signature]]).float()
# Reward is 3x the signature
reward = 3.0 * signature
# Next state is signature + 0.01, unless end of episode
# End episode every 10 steps
is_end = (i + 1) % 10 == 0
if is_end:
# At episode boundaries, next_state repeats current state (as per your implementation)
next_state = {"state_value": torch.tensor([[signature]]).float()}
done = True
else:
# Within episodes, next_state has signature + 0.01
next_signature = float(i + 1) / 100.0
next_state = {"state_value": torch.tensor([[next_signature]]).float()}
done = False
aligned_buffer.add(state, action, reward, next_state, done, False)
# Sample from this buffer
aligned_batch = aligned_buffer.sample(50)
# Verify alignments in sampled batch
correct_relationships = 0
total_checks = 0
# For each transition in the batch
for i in range(50):
# Extract signature from state
state_sig = aligned_batch["state"]["state_value"][i].item()
# Check action is 2x signature (within reasonable precision)
action_val = aligned_batch["action"][i].item()
action_check = abs(action_val - 2.0 * state_sig) < 1e-4
# Check reward is 3x signature (within reasonable precision)
reward_val = aligned_batch["reward"][i].item()
reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4
# Check next_state relationship matches our pattern
next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
is_done = aligned_batch["done"][i].item() > 0.5
# Calculate expected next_state value based on done flag
if is_done:
# For episodes that end, next_state should equal state
next_state_check = abs(next_state_sig - state_sig) < 1e-4
else:
# For continuing episodes, check if next_state is approximately state + 0.01
# We need to be careful because we don't know the original index
# So we check if the increment is roughly 0.01
next_state_check = (
abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
)
# Count correct relationships
if action_check:
correct_relationships += 1
if reward_check:
correct_relationships += 1
if next_state_check:
correct_relationships += 1
total_checks += 3
alignment_accuracy = 100.0 * correct_relationships / total_checks
print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
if alignment_accuracy > 99.0:
print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
else:
print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")
# Print some debug information about failures
print("\nDebug information for failed checks:")
for i in range(5): # Print first 5 transitions for debugging
state_sig = aligned_batch["state"]["state_value"][i].item()
action_val = aligned_batch["action"][i].item()
reward_val = aligned_batch["reward"][i].item()
next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
is_done = aligned_batch["done"][i].item() > 0.5
print(f"Transition {i}:")
print(f" State: {state_sig:.6f}")
print(f" Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
print(f" Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
print(f" Done: {is_done}")
print(f" Next state: {next_state_sig:.6f}")
# Calculate expected next state
if is_done:
expected_next = state_sig
else:
# This approximation might not be perfect
state_idx = round(state_sig * 100)
expected_next = (state_idx + 1) / 100.0
print(f" Expected next state: {expected_next:.6f}")
print()
# ===== Test 2: Convert to LeRobotDataset and back =====
with TemporaryDirectory() as temp_dir:
print("\nTesting conversion to LeRobotDataset and back...")
# Convert buffer to dataset
repo_id = "test/replay_buffer_conversion"
# Create a subdirectory to avoid the "directory exists" error
dataset_dir = os.path.join(temp_dir, "dataset1")
dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)
print(f"Dataset created with {len(dataset)} frames")
print(f"Dataset features: {list(dataset.features.keys())}") print(f"Dataset features: {list(dataset.features.keys())}")
# Check a random sample from the dataset # Check if dataset has complementary_info.gripper_penalty
sample = dataset[0] sample = dataset[0]
print( complementary_info_keys = [key for key in sample if key.startswith("complementary_info")]
f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}" print(f"Complementary info keys: {complementary_info_keys}")
if "complementary_info.gripper_penalty" in sample:
print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}")
# Extract state keys for the buffer
state_keys = []
for key in sample:
if key.startswith("observation"):
state_keys.append(key)
print(f"Using state keys: {state_keys}")
# Create a replay buffer from the dataset
start_time = time.time()
buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False
) )
load_time = time.time() - start_time
print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
print(f"Buffer size: {len(buffer)}")
# Convert dataset back to buffer # Check if complementary_info was transferred correctly
reconverted_buffer = ReplayBuffer.from_lerobot_dataset( print("Sampling from buffer to check complementary_info...")
dataset, state_keys=list(state_dims.keys()), device="cpu" batch = buffer.sample(batch_size=4)
)
print(f"Reconverted buffer size: {len(reconverted_buffer)}") if batch["complementary_info"] is not None:
print("Complementary info in batch:")
# Sample from the reconverted buffer for key, value in batch["complementary_info"].items():
reconverted_batch = reconverted_buffer.sample(batch_size) print(f" {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}")
print( if key == "gripper_penalty":
f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}" print(f" Sample gripper_penalty values: {value[:5]}")
)
# Verify consistency before and after conversion
original_states = batch["state"]["observation.image"].mean().item()
reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
print(f"Original buffer state mean: {original_states:.4f}")
print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
if abs(original_states - reconverted_states) < 1.0:
print("Values are reasonably similar - conversion works as expected")
else: else:
print("WARNING: Significant difference between original and reconverted values") print("No complementary_info found in batch")
print("\nAll previous tests completed!") # Now convert the buffer back to a LeRobotDataset
print("\nConverting buffer back to LeRobotDataset...")
# ===== Test for memory optimization ===== start_time = time.time()
print("\n===== Testing Memory Optimization =====") new_dataset = buffer.to_lerobot_dataset(
repo_id="test_dataset_from_buffer",
# Create two buffers, one with memory optimization and one without fps=dataset.fps,
standard_buffer = ReplayBuffer( root="./test_dataset_from_buffer",
capacity=1000, task_name="test_conversion",
device="cpu",
state_keys=["observation.image", "observation.state"],
storage_device="cpu",
optimize_memory=False,
use_drq=True,
) )
convert_time = time.time() - start_time
print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
print(f"New dataset size: {len(new_dataset)} frames")
optimized_buffer = ReplayBuffer( # Check if complementary_info was preserved
capacity=1000, new_sample = new_dataset[0]
device="cpu", new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")]
state_keys=["observation.image", "observation.state"], print(f"New dataset complementary info keys: {new_complementary_info_keys}")
storage_device="cpu",
optimize_memory=True,
use_drq=True,
)
# Generate sample data with larger state dimensions for better memory impact if "complementary_info.gripper_penalty" in new_sample:
print("Generating test data...") print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}")
num_episodes = 10
steps_per_episode = 50
total_steps = num_episodes * steps_per_episode
for episode in range(num_episodes): # Compare original and new datasets
for step in range(steps_per_episode): print("\nComparing original and new datasets:")
# Index in the overall sequence print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
i = episode * steps_per_episode + step print(f"Original features: {list(dataset.features.keys())}")
print(f"New features: {list(new_dataset.features.keys())}")
# Create state with identifiable values return buffer, dataset, new_dataset
img = torch.ones((3, 84, 84)) * (i / total_steps)
state_vec = torch.ones((10,)) * (i / total_steps)
state = { # Run the test
"observation.image": img.unsqueeze(0), test_load_dataset_with_complementary_info()
"observation.state": state_vec.unsqueeze(0),
}
# Create next state (i+1 or same as current if last in episode)
is_last_step = step == steps_per_episode - 1
if is_last_step:
# At episode end, next state = current state
next_img = img.clone()
next_state_vec = state_vec.clone()
done = True
truncated = False
else:
# Within episode, next state has incremented value
next_val = (i + 1) / total_steps
next_img = torch.ones((3, 84, 84)) * next_val
next_state_vec = torch.ones((10,)) * next_val
done = False
truncated = False
next_state = {
"observation.image": next_img.unsqueeze(0),
"observation.state": next_state_vec.unsqueeze(0),
}
# Action and reward
action = torch.tensor([[i / total_steps]])
reward = float(i / total_steps)
# Add to both buffers
standard_buffer.add(state, action, reward, next_state, done, truncated)
optimized_buffer.add(state, action, reward, next_state, done, truncated)
# Verify episode boundaries with our simplified approach
print("\nVerifying simplified memory optimization...")
# Test with a new buffer with a small sequence
test_buffer = ReplayBuffer(
capacity=20,
device="cpu",
state_keys=["value"],
storage_device="cpu",
optimize_memory=True,
use_drq=False,
)
# Add a simple sequence with known episode boundaries
for i in range(20):
val = float(i)
state = {"value": torch.tensor([[val]]).float()}
next_val = float(i + 1) if i % 5 != 4 else val # Episode ends every 5 steps
next_state = {"value": torch.tensor([[next_val]]).float()}
# Set done=True at every 5th step
done = (i % 5) == 4
action = torch.tensor([[0.0]])
reward = 1.0
truncated = False
test_buffer.add(state, action, reward, next_state, done, truncated)
# Get sequential batch for verification
sequential_batch_size = test_buffer.size
all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
# Get state tensors
batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}
# Get next_state using memory-optimized approach (simply index+1)
next_indices = (all_indices + 1) % test_buffer.capacity
batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}
# Get other tensors
batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
# Print sequential values
print("State, Next State, Done (Sequential values with simplified optimization):")
state_values = batch_state["value"].squeeze().tolist()
next_values = batch_next_state["value"].squeeze().tolist()
done_flags = batch_dones.tolist()
# Print all values
for i in range(len(state_values)):
print(f" {state_values[i]:.1f}{next_values[i]:.1f}, Done: {done_flags[i]}")
# Explain the memory optimization tradeoff
print("\nWith simplified memory optimization:")
print("- We always use the next state in the buffer (index+1) as next_state")
print("- For terminal states, this means using the first state of the next episode")
print("- This is a common tradeoff in RL implementations for memory efficiency")
print("- Since we track done flags, the algorithm can handle these transitions correctly")
# Test random sampling
print("\nVerifying random sampling with simplified memory optimization...")
random_samples = test_buffer.sample(20) # Sample all transitions
# Extract values
random_state_values = random_samples["state"]["value"].squeeze().tolist()
random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
random_done_flags = random_samples["done"].bool().tolist()
# Print a few samples
print("Random samples - State, Next State, Done (First 10):")
for i in range(10):
print(f" {random_state_values[i]:.1f}{random_next_values[i]:.1f}, Done: {random_done_flags[i]}")
# Calculate memory savings
# Assume optimized_buffer and standard_buffer have already been initialized and filled
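# The factor of 2 below reflects that the standard buffer keeps a separate next_state copy
# for every state, while the memory-optimized buffer stores each observation only once.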
std_mem = (
sum(
standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
for key in standard_buffer.states
)
* 2
)
opt_mem = sum(
optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
for key in optimized_buffer.states
)
savings_percent = (std_mem - opt_mem) / std_mem * 100
print("\nMemory optimization result:")
print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
print("\nAll memory optimization tests completed!")
# # ===== Test real dataset conversion =====
# print("\n===== Testing Real LeRobotDataset Conversion =====")
# try:
# # Try to use a real dataset if available
# dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
# dataset = LeRobotDataset(repo_id=dataset_name)
# # Print available keys to debug
# sample = dataset[0]
# print("Available keys in dataset:", list(sample.keys()))
# # Check for required keys
# if "action" not in sample or "next.reward" not in sample:
# print("Dataset missing essential keys. Cannot convert.")
# raise ValueError("Missing required keys in dataset")
# # Auto-detect appropriate state keys
# image_keys = []
# state_keys = []
# for k, v in sample.items():
# # Skip metadata keys and action/reward keys
# if k in {
# "index",
# "episode_index",
# "frame_index",
# "timestamp",
# "task_index",
# "action",
# "next.reward",
# "next.done",
# }:
# continue
# # Infer key type from tensor shape
# if isinstance(v, torch.Tensor):
# if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
# # Likely an image (channels, height, width)
# image_keys.append(k)
# else:
# # Likely state or other vector
# state_keys.append(k)
# print(f"Detected image keys: {image_keys}")
# print(f"Detected state keys: {state_keys}")
# if not image_keys and not state_keys:
# print("No usable keys found in dataset, skipping further tests")
# raise ValueError("No usable keys found in dataset")
# # Test with standard and memory-optimized buffers
# for optimize_memory in [False, True]:
# buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
# print(f"\nTesting {buffer_type} buffer with real dataset...")
# # Convert to ReplayBuffer with detected keys
# replay_buffer = ReplayBuffer.from_lerobot_dataset(
# lerobot_dataset=dataset,
# state_keys=image_keys + state_keys,
# device="cpu",
# optimize_memory=optimize_memory,
# )
# print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
# # Test sampling
# real_batch = replay_buffer.sample(32)
# print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
# for key in real_batch["state"]:
# print(f" {key}: {real_batch['state'][key].shape}")
# # Convert back to LeRobotDataset
# with TemporaryDirectory() as temp_dir:
# dataset_name = f"test/real_dataset_converted_{buffer_type}"
# replay_buffer_converted = replay_buffer.to_lerobot_dataset(
# repo_id=dataset_name,
# root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
# )
# print(
# f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
# )
# except Exception as e:
# print(f"Real dataset test failed: {e}")
# print("This is expected if running offline or if the dataset is not available.")
# print("\nAll tests completed!")

View File

@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
return observation
class GripperPenaltyWrapper(gym.RewardWrapper):
def __init__(self, env, penalty: float = -0.1):
super().__init__(env)
self.penalty = penalty
self.last_gripper_state = None
def reward(self, reward, action):
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
if isinstance(action, tuple):
action = action[0]
action_normalized = action[-1] / MAX_GRIPPER_COMMAND
gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
gripper_state_normalized > 0.9 and action_normalized < 0.1
)
return reward + self.penalty * gripper_penalty_bool
def step(self, action):
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
obs, reward, terminated, truncated, info = self.env.step(action)
reward = self.reward(reward, action)
return obs, reward, terminated, truncated, info
def reset(self, **kwargs):
self.last_gripper_state = None
return super().reset(**kwargs)
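# Minimal sketch of the penalty rule above, using normalized values only (MAX_GRIPPER_COMMAND
# handles the scaling): the penalty fires when the gripper sits at one extreme while the
# command points to the opposite extreme. Names and numbers here are illustrative.
def _gripper_penalty_triggered(state_norm: float, action_norm: float) -> bool:
    return (state_norm < 0.1 and action_norm > 0.9) or (state_norm > 0.9 and action_norm < 0.1)
# _gripper_penalty_triggered(0.05, 0.95) -> True  (command contradicts the current state)
# _gripper_penalty_triggered(0.5, 0.5)   -> False (no contradiction, no penalty)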
class GripperQuantizationWrapper(gym.ActionWrapper):
def __init__(self, env, quantization_threshold: float = 0.2):
super().__init__(env)
self.quantization_threshold = quantization_threshold
def action(self, action):
is_intervention = False
if isinstance(action, tuple):
action, is_intervention = action
gripper_command = action[-1]
# Quantize gripper command to -1, 0 or 1
if gripper_command < -self.quantization_threshold:
gripper_command = -MAX_GRIPPER_COMMAND
elif gripper_command > self.quantization_threshold:
gripper_command = MAX_GRIPPER_COMMAND
else:
gripper_command = 0.0
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
action[-1] = gripper_action.item()
return action, is_intervention
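# Standalone sketch of the quantization above (MAX_GRIPPER_COMMAND is defined elsewhere in this
# module; the numbers below are placeholders). A command inside the dead zone leaves the gripper
# where it is; anything beyond the threshold takes a full step that is then clipped to range.
def _quantized_gripper_target(command: float, current_position: float, threshold: float = 0.2) -> float:
    if command < -threshold:
        step = -MAX_GRIPPER_COMMAND
    elif command > threshold:
        step = MAX_GRIPPER_COMMAND
    else:
        step = 0.0
    return float(np.clip(current_position + step, 0, MAX_GRIPPER_COMMAND))
# e.g. _quantized_gripper_target(0.7, current_position=12.0) == MAX_GRIPPER_COMMAND (full step, clipped)
#      _quantized_gripper_target(0.05, current_position=12.0) == 12.0 (dead zone, hold position)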
class EEActionWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None, use_gripper=False):
super().__init__(env)
@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper):
fk_func=self.fk_function,
)
if self.use_gripper:
# Quantize gripper command to -1, 0 or 1
if gripper_command < -0.2:
gripper_command = -1.0
elif gripper_command > 0.2:
gripper_command = 1.0
else:
gripper_command = 0.0
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
target_joint_pos[-1] = gripper_action
target_joint_pos[-1] = gripper_command
return target_joint_pos, is_intervention
@ -1118,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
# Add reward computation and control wrappers
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper:
env = GripperQuantizationWrapper(
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
)
# env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(
env=env,

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
@ -269,6 +269,7 @@ def add_actor_information_and_train(
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
async_prefetch = cfg.policy.async_prefetch
# Initialize logging for multiprocessing
if not use_threads(cfg):
@ -326,6 +327,9 @@ def add_actor_information_and_train(
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested
@ -359,13 +363,26 @@ def add_actor_information_and_train(
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
logging.debug("[LEARNER] Initializing online replay buffer iterator")
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
logging.debug("[LEARNER] Starting optimization loop") logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time() time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1): for _ in range(utd_ratio - 1):
batch = replay_buffer.sample(batch_size=batch_size) # Sample from the iterators
batch = next(online_iterator)
if dataset_repo_id is not None: if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size=batch_size) batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions( batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline left_batch_transitions=batch, right_batch_transition=batch_offline
) )
@ -392,24 +409,37 @@ def add_actor_information_and_train(
"next_observation_feature": next_observation_features, "next_observation_feature": next_observation_features,
} }
# Use the forward method for critic loss (includes both main critic and grasp critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
# Grasp critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
# Update target networks
policy.update_target_networks()
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
@ -437,63 +467,80 @@ def add_actor_information_and_train(
"next_observation_feature": next_observation_features, "next_observation_feature": next_observation_features,
} }
# Use the forward method for critic loss (includes both main critic and grasp critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
# Initialize training info dictionary
training_infos = {
"loss_critic": loss_critic.item(),
"critic_grad_norm": critic_grad_norm,
}
# Grasp critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
# Add grasp critic info to training info
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Actor optimization
actor_output = policy.forward(forward_batch, model="actor")
loss_actor = actor_output["loss_actor"]
optimizers["actor"].zero_grad()
loss_actor.backward()
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
temperature_output = policy.forward(forward_batch, model="temperature")
loss_temperature = temperature_output["loss_temperature"]
optimizers["temperature"].zero_grad()
loss_temperature.backward()
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# Update temperature
policy.update_temperature()
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
# Update target networks
policy.update_target_networks()
# Log training metrics at specified intervals
@ -697,7 +744,7 @@ def save_training_checkpoint(
logging.info("Resume training") logging.info("Resume training")
def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
""" """
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
@ -728,7 +775,14 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
params=policy.actor.parameters_to_optimize,
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
)
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
@ -736,6 +790,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"critic": optimizer_critic, "critic": optimizer_critic,
"temperature": optimizer_temperature, "temperature": optimizer_temperature,
} }
if cfg.policy.num_discrete_actions is not None:
optimizers["grasp_critic"] = optimizer_grasp_critic
return optimizers, lr_scheduler
@ -970,12 +1026,8 @@ def get_observation_features(
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
)
next_observation_features = (
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
)
observation_features = policy.actor.encoder.get_image_features(observations)
next_observation_features = policy.actor.encoder.get_image_features(next_observations)
return observation_features, next_observation_features

View File

@ -1,5 +1,3 @@
import logging
import time
from typing import Any
import einops
@ -10,7 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.configs import parser
def preprocess_maniskill_observation(
@ -153,6 +150,27 @@ class TimeLimitWrapper(gym.Wrapper):
return super().reset(seed=seed, options=options)
class ManiskillMockGripperWrapper(gym.Wrapper):
def __init__(self, env, nb_discrete_actions: int = 3):
super().__init__(env)
new_shape = env.action_space[0].shape[0] + 1
new_low = np.concatenate([env.action_space[0].low, [0]])
new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
def step(self, action):
if isinstance(action, tuple):
action_agent, telop_action = action
else:
telop_action = 0
action_agent = action
real_action = action_agent[:-1]
final_action = (real_action, telop_action)
obs, reward, terminated, truncated, info = self.env.step(final_action)
return obs, reward, terminated, truncated, info
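# Illustrative shapes for the wrapper above, assuming a 6-dimensional continuous agent action
# space (placeholder values): the wrapped agent space gains one trailing entry holding a
# discrete gripper index in {0, ..., nb_discrete_actions - 1}, and step() strips it off again
# before forwarding the continuous part together with the teleop action.
#
# action_agent = np.array([0.1, -0.2, 0.0, 0.0, 0.0, 0.05, 2.0])  # last entry = gripper index
# real_action = action_agent[:-1]                                  # continuous part only
# final_action = (real_action, 0)                                  # 0 = default teleop action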
def make_maniskill(
cfg: ManiskillEnvConfig,
n_envs: int | None = None,
@ -197,40 +215,42 @@ def make_maniskill(
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
if cfg.mock_gripper:
env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
return env
# @parser.wrap()
# def main(cfg: TrainPipelineConfig):
# """Main function to run the ManiSkill environment."""
# # Create the ManiSkill environment
# env = make_maniskill(cfg.env, n_envs=1)
# # Reset the environment
# obs, info = env.reset()
# # Run a simple interaction loop
# sum_reward = 0
# for i in range(100):
# # Sample a random action
# action = env.action_space.sample()
# # Step the environment
# start_time = time.perf_counter()
# obs, reward, terminated, truncated, info = env.step(action)
# step_time = time.perf_counter() - start_time
# sum_reward += reward
# # Log information
# # Reset if episode terminated
# if terminated or truncated:
# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
# sum_reward = 0
# obs, info = env.reset()
# # Close the environment
# env.close()
# if __name__ == "__main__":