diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 825fa162..a6eda93b 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -203,6 +203,9 @@ class EnvWrapperConfig: joint_masking_action_space: Optional[Any] = None ee_action_space_params: Optional[EEActionSpaceConfig] = None use_gripper: bool = False + gripper_quantization_threshold: float = 0.8 + gripper_penalty: float = 0.0 + open_gripper_on_reset: bool = False @EnvConfig.register_subclass(name="gym_manipulator") @@ -254,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig): robot: str = "so100" # This is a hack to make the robot config work video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) wrapper: WrapperConfig = field(default_factory=WrapperConfig) + mock_gripper: bool = False features: dict[str, PolicyFeature] = field( default_factory=lambda: { "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 906a3bed..3d01f47c 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -85,12 +85,14 @@ class SACConfig(PreTrainedConfig): freeze_vision_encoder: Whether to freeze the vision encoder during training. image_encoder_hidden_dim: Hidden dimension size for the image encoder. shared_encoder: Whether to use a shared encoder for actor and critic. + num_discrete_actions: Number of discrete actions, eg for gripper actions. concurrency: Configuration for concurrency settings. actor_learner: Configuration for actor-learner architecture. online_steps: Number of steps for online training. online_env_seed: Seed for the online environment. online_buffer_capacity: Capacity of the online replay buffer. offline_buffer_capacity: Capacity of the offline replay buffer. + async_prefetch: Whether to use asynchronous prefetching for the buffers. online_step_before_learning: Number of steps before learning starts. policy_update_freq: Frequency of policy updates. discount: Discount factor for the SAC algorithm. @@ -144,12 +146,14 @@ class SACConfig(PreTrainedConfig): freeze_vision_encoder: bool = True image_encoder_hidden_dim: int = 32 shared_encoder: bool = True + num_discrete_actions: int | None = None # Training parameter online_steps: int = 1000000 online_env_seed: int = 10000 online_buffer_capacity: int = 100000 offline_buffer_capacity: int = 100000 + async_prefetch: bool = False online_step_before_learning: int = 100 policy_update_freq: int = 1 @@ -173,7 +177,7 @@ class SACConfig(PreTrainedConfig): critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) - + grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index f7866714..e3d83d36 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -33,6 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension + class SACPolicy( PreTrainedPolicy, @@ -49,6 +51,8 @@ class SACPolicy( config.validate_features() self.config = config + continuous_action_dim = config.output_features["action"].shape[0] + if config.dataset_stats is not None: input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats) self.normalize_inputs = Normalize( @@ -77,11 +81,12 @@ class SACPolicy( else: encoder_critic = SACObservationEncoder(config, self.normalize_inputs) encoder_actor = SACObservationEncoder(config, self.normalize_inputs) + self.shared_encoder = config.shared_encoder # Create a list of critic heads critic_heads = [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], + input_dim=encoder_critic.output_dim + continuous_action_dim, **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) @@ -96,7 +101,7 @@ class SACPolicy( # Create target critic heads as deepcopies of the original critic heads target_critic_heads = [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0], + input_dim=encoder_critic.output_dim + continuous_action_dim, **asdict(config.critic_network_kwargs), ) for _ in range(config.num_critics) @@ -112,15 +117,41 @@ class SACPolicy( self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_target = torch.compile(self.critic_target) + + self.grasp_critic = None + self.grasp_critic_target = None + + if config.num_discrete_actions is not None: + # Create grasp critic + self.grasp_critic = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + output_dim=config.num_discrete_actions, + **asdict(config.grasp_critic_network_kwargs), + ) + + # Create target grasp critic + self.grasp_critic_target = GraspCritic( + encoder=encoder_critic, + input_dim=encoder_critic.output_dim, + output_dim=config.num_discrete_actions, + **asdict(config.grasp_critic_network_kwargs), + ) + + self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) + + self.grasp_critic = torch.compile(self.grasp_critic) + self.grasp_critic_target = torch.compile(self.grasp_critic_target) + self.actor = Policy( encoder=encoder_actor, network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)), - action_dim=config.output_features["action"].shape[0], + action_dim=continuous_action_dim, encoder_is_shared=config.shared_encoder, **asdict(config.policy_kwargs), ) if config.target_entropy is None: - config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2) + config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2) # TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise @@ -131,11 +162,14 @@ class SACPolicy( self.temperature = self.log_alpha.exp().item() def get_optim_params(self) -> dict: - return { + optim_params = { "actor": self.actor.parameters_to_optimize, "critic": self.critic_ensemble.parameters_to_optimize, "temperature": self.log_alpha, } + if self.config.num_discrete_actions is not None: + optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize + return optim_params def reset(self): """Reset the policy""" @@ -151,8 +185,19 @@ class SACPolicy( @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select action for inference/evaluation""" - actions, _, _ = self.actor(batch) + # We cached the encoder output to avoid recomputing it + observations_features = None + if self.shared_encoder: + observations_features = self.actor.encoder.get_image_features(batch) + + actions, _, _ = self.actor(batch, observations_features) actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.num_discrete_actions is not None: + discrete_action_value = self.grasp_critic(batch, observations_features) + discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) + actions = torch.cat([actions, discrete_action], dim=-1) + return actions def critic_forward( @@ -172,14 +217,30 @@ class SACPolicy( Returns: Tensor of Q-values from all critics """ + critics = self.critic_target if use_target else self.critic_ensemble q_values = critics(observations, actions, observation_features) return q_values + def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor: + """Forward pass through a grasp critic network + + Args: + observations: Dictionary of observations + use_target: If True, use target critics, otherwise use ensemble critics + observation_features: Optional pre-computed observation features to avoid recomputing encoder output + + Returns: + Tensor of Q-values from the grasp critic network + """ + grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic + q_values = grasp_critic(observations, observation_features) + return q_values + def forward( self, batch: dict[str, Tensor | dict[str, Tensor]], - model: Literal["actor", "critic", "temperature"] = "critic", + model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic", ) -> dict[str, Tensor]: """Compute the loss for the given model @@ -192,12 +253,11 @@ class SACPolicy( - done: Done mask tensor - observation_feature: Optional pre-computed observation features - next_observation_feature: Optional pre-computed next observation features - model: Which model to compute the loss for ("actor", "critic", or "temperature") + model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature") Returns: The computed loss tensor """ - # TODO: (maractingi, azouitine) Respect the function signature we output tensors # Extract common components from batch actions: Tensor = batch["action"] observations: dict[str, Tensor] = batch["state"] @@ -210,7 +270,7 @@ class SACPolicy( done: Tensor = batch["done"] next_observation_features: Tensor = batch.get("next_observation_feature") - return self.compute_loss_critic( + loss_critic = self.compute_loss_critic( observations=observations, actions=actions, rewards=rewards, @@ -220,17 +280,41 @@ class SACPolicy( next_observation_features=next_observation_features, ) - if model == "actor": - return self.compute_loss_actor( + return {"loss_critic": loss_critic} + + if model == "grasp_critic" and self.config.num_discrete_actions is not None: + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") + loss_grasp_critic = self.compute_loss_grasp_critic( observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, observation_features=observation_features, + next_observation_features=next_observation_features, + complementary_info=complementary_info, ) + return {"loss_grasp_critic": loss_grasp_critic} + if model == "actor": + return { + "loss_actor": self.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) + } if model == "temperature": - return self.compute_loss_temperature( - observations=observations, - observation_features=observation_features, - ) + return { + "loss_temperature": self.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + } raise ValueError(f"Unknown model type: {model}") @@ -245,6 +329,16 @@ class SACPolicy( param.data * self.config.critic_target_update_weight + target_param.data * (1.0 - self.config.critic_target_update_weight) ) + if self.config.num_discrete_actions is not None: + for target_param, param in zip( + self.grasp_critic_target.parameters(), + self.grasp_critic.parameters(), + strict=False, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) def update_temperature(self): self.temperature = self.log_alpha.exp().item() @@ -287,6 +381,11 @@ class SACPolicy( td_target = rewards + (1 - done) * self.config.discount * min_q # 3- compute predicted qs + if self.config.num_discrete_actions is not None: + # NOTE: We only want to keep the continuous action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] q_preds = self.critic_forward( observations=observations, actions=actions, @@ -307,6 +406,64 @@ class SACPolicy( ).sum() return critics_loss + def compute_loss_grasp_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features=None, + next_observation_features=None, + complementary_info=None, + ): + # NOTE: We only want to keep the discrete action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) + actions_discrete = actions_discrete.long() + + if complementary_info is not None: + gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty") + + with torch.no_grad(): + # For DQN, select actions using online network, evaluate with target network + next_grasp_qs = self.grasp_critic_forward( + next_observations, use_target=False, observation_features=next_observation_features + ) + best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True) + + # Get target Q-values from target network + target_next_grasp_qs = self.grasp_critic_forward( + observations=next_observations, + use_target=True, + observation_features=next_observation_features, + ) + + # Use gather to select Q-values for best actions + target_next_grasp_q = torch.gather( + target_next_grasp_qs, dim=1, index=best_next_grasp_action + ).squeeze(-1) + + # Compute target Q-value with Bellman equation + rewards_gripper = rewards + if gripper_penalties is not None: + rewards_gripper = rewards + gripper_penalties + target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q + + # Get predicted Q-values for current observations + predicted_grasp_qs = self.grasp_critic_forward( + observations=observations, use_target=False, observation_features=observation_features + ) + + # Use gather to select Q-values for taken actions + predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1) + + # Compute MSE loss between predicted and target Q-values + grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q) + return grasp_critic_loss + def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: """Compute the temperature loss""" # calculate temperature loss @@ -337,6 +494,109 @@ class SACPolicy( return actor_loss +class SACObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: SACConfig, input_normalizer: nn.Module): + """ + Creates encoders for pixel and/or state modalities. + """ + super().__init__() + self.config = config + self.input_normalization = input_normalizer + self.has_pretrained_vision_encoder = False + self.parameters_to_optimize = [] + + self.aggregation_size: int = 0 + if any("observation.image" in key for key in config.input_features): + self.camera_number = config.camera_number + + if self.config.vision_encoder_name is not None: + self.image_enc_layers = PretrainedImageEncoder(config) + self.has_pretrained_vision_encoder = True + else: + self.image_enc_layers = DefaultImageEncoder(config) + + self.aggregation_size += config.latent_dim * self.camera_number + + if config.freeze_vision_encoder: + freeze_image_encoder(self.image_enc_layers) + else: + self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] + + if "observation.state" in config.input_features: + self.state_enc_layers = nn.Sequential( + nn.Linear( + in_features=config.input_features["observation.state"].shape[0], + out_features=config.latent_dim, + ), + nn.LayerNorm(normalized_shape=config.latent_dim), + nn.Tanh(), + ) + self.aggregation_size += config.latent_dim + + self.parameters_to_optimize += list(self.state_enc_layers.parameters()) + + if "observation.environment_state" in config.input_features: + self.env_state_enc_layers = nn.Sequential( + nn.Linear( + in_features=config.input_features["observation.environment_state"].shape[0], + out_features=config.latent_dim, + ), + nn.LayerNorm(normalized_shape=config.latent_dim), + nn.Tanh(), + ) + self.aggregation_size += config.latent_dim + self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) + + self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) + self.parameters_to_optimize += list(self.aggregation_layer.parameters()) + + def forward( + self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None + ) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + obs_dict = self.input_normalization(obs_dict) + if len(self.all_image_keys) > 0 and vision_encoder_cache is None: + vision_encoder_cache = self.get_image_features(obs_dict) + feat.append(vision_encoder_cache) + + if vision_encoder_cache is not None: + feat.append(vision_encoder_cache) + + if "observation.environment_state" in self.config.input_features: + feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + if "observation.state" in self.config.input_features: + feat.append(self.state_enc_layers(obs_dict["observation.state"])) + + features = torch.cat(tensors=feat, dim=-1) + features = self.aggregation_layer(features) + + return features + + def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor: + # [N*B, C, H, W] + if len(self.all_image_keys) > 0: + # Batch all images along the batch dimension, then encode them. + images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0) + images_batched = self.image_enc_layers(images_batched) + embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) + embeddings_image = torch.cat(embeddings_chunks, dim=-1) + return embeddings_image + return None + + @property + def output_dim(self) -> int: + """Returns the dimension of the encoder output""" + return self.config.latent_dim + + class MLP(nn.Module): def __init__( self, @@ -459,7 +719,7 @@ class CriticEnsemble(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: SACObservationEncoder, ensemble: List[CriticHead], output_normalization: nn.Module, init_final: Optional[float] = None, @@ -491,11 +751,7 @@ class CriticEnsemble(nn.Module): actions = self.output_normalization(actions)["action"] actions = actions.to(device) - obs_enc = ( - observation_features - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, observation_features) inputs = torch.cat([obs_enc, actions], dim=-1) @@ -509,10 +765,57 @@ class CriticEnsemble(nn.Module): return q_values +class GraspCritic(nn.Module): + def __init__( + self, + encoder: nn.Module, + input_dim: int, + hidden_dims: list[int], + output_dim: int = 3, + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: Optional[float] = None, + init_final: Optional[float] = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + self.encoder = encoder + self.output_dim = output_dim + + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) + + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + self.parameters_to_optimize = [] + self.parameters_to_optimize += list(self.net.parameters()) + self.parameters_to_optimize += list(self.output_layer.parameters()) + + def forward( + self, observations: torch.Tensor, observation_features: torch.Tensor | None = None + ) -> torch.Tensor: + device = get_device_from_parameters(self) + # Move each tensor in observations to device by cloning first to avoid inplace operations + observations = {k: v.to(device) for k, v in observations.items()} + obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) + return self.output_layer(self.net(obs_enc)) + + class Policy(nn.Module): def __init__( self, - encoder: Optional[nn.Module], + encoder: SACObservationEncoder, network: nn.Module, action_dim: int, log_std_min: float = -5, @@ -523,7 +826,7 @@ class Policy(nn.Module): encoder_is_shared: bool = False, ): super().__init__() - self.encoder = encoder + self.encoder: SACObservationEncoder = encoder self.network = network self.action_dim = action_dim self.log_std_min = log_std_min @@ -566,11 +869,7 @@ class Policy(nn.Module): observation_features: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Encode observations if encoder exists - obs_enc = ( - observation_features - if observation_features is not None - else (observations if self.encoder is None else self.encoder(observations)) - ) + obs_enc = self.encoder(observations, vision_encoder_cache=observation_features) # Get network outputs outputs = self.network(obs_enc) @@ -614,96 +913,6 @@ class Policy(nn.Module): return observations -class SACObservationEncoder(nn.Module): - """Encode image and/or state vector observations.""" - - def __init__(self, config: SACConfig, input_normalizer: nn.Module): - """ - Creates encoders for pixel and/or state modalities. - """ - super().__init__() - self.config = config - self.input_normalization = input_normalizer - self.has_pretrained_vision_encoder = False - self.parameters_to_optimize = [] - - self.aggregation_size: int = 0 - if any("observation.image" in key for key in config.input_features): - self.camera_number = config.camera_number - - if self.config.vision_encoder_name is not None: - self.image_enc_layers = PretrainedImageEncoder(config) - self.has_pretrained_vision_encoder = True - else: - self.image_enc_layers = DefaultImageEncoder(config) - - self.aggregation_size += config.latent_dim * self.camera_number - - if config.freeze_vision_encoder: - freeze_image_encoder(self.image_enc_layers) - else: - self.parameters_to_optimize += list(self.image_enc_layers.parameters()) - self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")] - - if "observation.state" in config.input_features: - self.state_enc_layers = nn.Sequential( - nn.Linear( - in_features=config.input_features["observation.state"].shape[0], - out_features=config.latent_dim, - ), - nn.LayerNorm(normalized_shape=config.latent_dim), - nn.Tanh(), - ) - self.aggregation_size += config.latent_dim - - self.parameters_to_optimize += list(self.state_enc_layers.parameters()) - - if "observation.environment_state" in config.input_features: - self.env_state_enc_layers = nn.Sequential( - nn.Linear( - in_features=config.input_features["observation.environment_state"].shape[0], - out_features=config.latent_dim, - ), - nn.LayerNorm(normalized_shape=config.latent_dim), - nn.Tanh(), - ) - self.aggregation_size += config.latent_dim - self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) - - self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) - self.parameters_to_optimize += list(self.aggregation_layer.parameters()) - - def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: - """Encode the image and/or state vector. - - Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken - over all features. - """ - feat = [] - obs_dict = self.input_normalization(obs_dict) - # Batch all images along the batch dimension, then encode them. - if len(self.all_image_keys) > 0: - images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0) - images_batched = self.image_enc_layers(images_batched) - embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) - feat.extend(embeddings_chunks) - - if "observation.environment_state" in self.config.input_features: - feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) - if "observation.state" in self.config.input_features: - feat.append(self.state_enc_layers(obs_dict["observation.state"])) - - features = torch.cat(tensors=feat, dim=-1) - features = self.aggregation_layer(features) - - return features - - @property - def output_dim(self) -> int: - """Returns the dimension of the encoder output""" - return self.config.latent_dim - - class DefaultImageEncoder(nn.Module): def __init__(self, config: SACConfig): super().__init__() diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 776ad9ec..8db1a82c 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -15,7 +15,6 @@ # limitations under the License. import functools import io -import os import pickle from typing import Any, Callable, Optional, Sequence, TypedDict @@ -33,7 +32,7 @@ class Transition(TypedDict): next_state: dict[str, torch.Tensor] done: bool truncated: bool - complementary_info: dict[str, Any] = None + complementary_info: dict[str, torch.Tensor | float | int] | None = None class BatchTransition(TypedDict): @@ -43,41 +42,47 @@ class BatchTransition(TypedDict): next_state: dict[str, torch.Tensor] done: torch.Tensor truncated: torch.Tensor + complementary_info: dict[str, torch.Tensor | float | int] | None = None def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: - # Move state tensors to CPU device = torch.device(device) + non_blocking = device.type == "cuda" + + # Move state tensors to device transition["state"] = { - key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items() + key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items() } - # Move action to CPU - transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda") + # Move action to device + transition["action"] = transition["action"].to(device, non_blocking=non_blocking) - # No need to move reward or done, as they are float and bool - - # No need to move reward or done, as they are float and bool + # Move reward and done if they are tensors if isinstance(transition["reward"], torch.Tensor): - transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda") + transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking) if isinstance(transition["done"], torch.Tensor): - transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda") + transition["done"] = transition["done"].to(device, non_blocking=non_blocking) if isinstance(transition["truncated"], torch.Tensor): - transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda") + transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking) - # Move next_state tensors to CPU + # Move next_state tensors to device transition["next_state"] = { - key: val.to(device, non_blocking=device.type == "cuda") - for key, val in transition["next_state"].items() + key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items() } - # If complementary_info is present, move its tensors to CPU - # if transition["complementary_info"] is not None: - # transition["complementary_info"] = { - # key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items() - # } + # Move complementary_info tensors if present + if transition.get("complementary_info") is not None: + for key, val in transition["complementary_info"].items(): + if isinstance(val, torch.Tensor): + transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) + elif isinstance(val, (int, float, bool)): + transition["complementary_info"][key] = torch.tensor( + val, device=device, non_blocking=non_blocking + ) + else: + raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") return transition @@ -217,7 +222,12 @@ class ReplayBuffer: self.image_augmentation_function = torch.compile(base_function) self.use_drq = use_drq - def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor): + def _initialize_storage( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + complementary_info: Optional[dict[str, torch.Tensor]] = None, + ): """Initialize the storage tensors based on the first transition.""" # Determine shapes from the first transition state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} @@ -244,6 +254,27 @@ class ReplayBuffer: self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + + # Initialize storage for complementary_info + self.has_complementary_info = complementary_info is not None + self.complementary_info_keys = [] + self.complementary_info = {} + + if self.has_complementary_info: + self.complementary_info_keys = list(complementary_info.keys()) + # Pre-allocate tensors for each key in complementary_info + for key, value in complementary_info.items(): + if isinstance(value, torch.Tensor): + value_shape = value.squeeze(0).shape + self.complementary_info[key] = torch.empty( + (self.capacity, *value_shape), device=self.storage_device + ) + elif isinstance(value, (int, float)): + # Handle scalar values similar to reward + self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) + else: + raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]") + self.initialized = True def __len__(self): @@ -262,7 +293,7 @@ class ReplayBuffer: """Saves a transition, ensuring tensors are stored on the designated storage device.""" # Initialize storage if this is the first transition if not self.initialized: - self._initialize_storage(state=state, action=action) + self._initialize_storage(state=state, action=action, complementary_info=complementary_info) # Store the transition in pre-allocated tensors for key in self.states: @@ -277,6 +308,17 @@ class ReplayBuffer: self.dones[self.position] = done self.truncateds[self.position] = truncated + # Handle complementary_info if provided and storage is initialized + if complementary_info is not None and self.has_complementary_info: + # Store the complementary_info + for key in self.complementary_info_keys: + if key in complementary_info: + value = complementary_info[key] + if isinstance(value, torch.Tensor): + self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) + elif isinstance(value, (int, float)): + self.complementary_info[key][self.position] = value + self.position = (self.position + 1) % self.capacity self.size = min(self.size + 1, self.capacity) @@ -335,6 +377,13 @@ class ReplayBuffer: batch_dones = self.dones[idx].to(self.device).float() batch_truncateds = self.truncateds[idx].to(self.device).float() + # Sample complementary_info if available + batch_complementary_info = None + if self.has_complementary_info: + batch_complementary_info = {} + for key in self.complementary_info_keys: + batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device) + return BatchTransition( state=batch_state, action=batch_actions, @@ -342,8 +391,112 @@ class ReplayBuffer: next_state=batch_next_state, done=batch_dones, truncated=batch_truncateds, + complementary_info=batch_complementary_info, ) + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """ + Creates an infinite iterator that yields batches of transitions. + Will automatically restart when internal iterator is exhausted. + + Args: + batch_size (int): Size of batches to sample + async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True) + queue_size (int): Number of batches to prefetch (default: 2) + + Yields: + BatchTransition: Batched transitions + """ + while True: # Create an infinite loop + if async_prefetch: + # Get the standard iterator + iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size) + else: + iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size) + + # Yield all items from the iterator + try: + yield from iterator + except StopIteration: + # Just continue the outer loop to create a new iterator + pass + + def _get_async_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates an iterator that prefetches batches in a background thread. + + Args: + queue_size (int): Number of batches to prefetch (default: 2) + batch_size (int): Size of batches to sample (default: 128) + + Yields: + BatchTransition: Prefetched batch transitions + """ + import queue + import threading + + # Use thread-safe queue + data_queue = queue.Queue(maxsize=queue_size) + running = [True] # Use list to allow modification in nested function + + def prefetch_worker(): + while running[0]: + try: + # Sample data and add to queue + data = self.sample(batch_size) + data_queue.put(data, block=True, timeout=0.5) + except queue.Full: + continue + except Exception as e: + print(f"Prefetch error: {e}") + break + + # Start prefetching thread + thread = threading.Thread(target=prefetch_worker, daemon=True) + thread.start() + + try: + while running[0]: + try: + yield data_queue.get(block=True, timeout=0.5) + except queue.Empty: + if not thread.is_alive(): + break + finally: + # Clean up + running[0] = False + thread.join(timeout=1.0) + + def _get_naive_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates a simple non-threaded iterator that yields batches. + + Args: + batch_size (int): Size of batches to sample + queue_size (int): Number of initial batches to prefetch + + Yields: + BatchTransition: Batch transitions + """ + import collections + + queue = collections.deque() + + def enqueue(n): + for _ in range(n): + data = self.sample(batch_size) + queue.append(data) + + enqueue(queue_size) + while queue: + yield queue.popleft() + enqueue(1) + @classmethod def from_lerobot_dataset( cls, @@ -415,7 +568,19 @@ class ReplayBuffer: if action_delta is not None: first_action = first_action / action_delta - replay_buffer._initialize_storage(state=first_state, action=first_action) + # Get complementary info if available + first_complementary_info = None + if ( + "complementary_info" in first_transition + and first_transition["complementary_info"] is not None + ): + first_complementary_info = { + k: v.to(device) for k, v in first_transition["complementary_info"].items() + } + + replay_buffer._initialize_storage( + state=first_state, action=first_action, complementary_info=first_complementary_info + ) # Fill the buffer with all transitions for data in list_transition: @@ -443,6 +608,7 @@ class ReplayBuffer: next_state=data["next_state"], done=data["done"], truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset + complementary_info=data.get("complementary_info", None), ) return replay_buffer @@ -484,6 +650,15 @@ class ReplayBuffer: f_info = guess_feature_info(t=sample_val, name=key) features[key] = f_info + # Add complementary_info keys if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + sample_val = self.complementary_info[key][0] + if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0: + sample_val = sample_val.unsqueeze(0) + f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") + features[f"complementary_info.{key}"] = f_info + # Create an empty LeRobotDataset lerobot_dataset = LeRobotDataset.create( repo_id=repo_id, @@ -517,6 +692,19 @@ class ReplayBuffer: frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + # Add complementary_info if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + val = self.complementary_info[key][actual_idx] + # Convert tensors to CPU + if isinstance(val, torch.Tensor): + if val.ndim == 0: + val = val.unsqueeze(0) + frame_dict[f"complementary_info.{key}"] = val.cpu() + # Non-tensor values can be used directly + else: + frame_dict[f"complementary_info.{key}"] = val + # Add task field which is required by LeRobotDataset frame_dict["task"] = task_name @@ -583,6 +771,10 @@ class ReplayBuffer: sample = dataset[0] has_done_key = "next.done" in sample + # Check for complementary_info keys + complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] + has_complementary_info = len(complementary_info_keys) > 0 + # If not, we need to infer it from episode boundaries if not has_done_key: print("'next.done' key not found in dataset. Inferring from episode boundaries...") @@ -632,6 +824,22 @@ class ReplayBuffer: next_state_data[key] = val.unsqueeze(0) # Add batch dimension next_state = next_state_data + # ----- 5) Complementary info (if available) ----- + complementary_info = None + if has_complementary_info: + complementary_info = {} + for key in complementary_info_keys: + # Strip the "complementary_info." prefix to get the actual key + clean_key = key[len("complementary_info.") :] + val = current_sample[key] + # Handle tensor and non-tensor values differently + if isinstance(val, torch.Tensor): + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + else: + # TODO: (azouitine) Check if it's necessary to convert to tensor + # For non-tensor values, use directly + complementary_info[clean_key] = val + # ----- Construct the Transition ----- transition = Transition( state=current_state, @@ -640,6 +848,7 @@ class ReplayBuffer: next_state=next_state, done=done, truncated=truncated, + complementary_info=complementary_info, ) transitions.append(transition) @@ -647,12 +856,13 @@ class ReplayBuffer: # Utility function to guess shapes/dtypes from a tensor -def guess_feature_info(t: torch.Tensor, name: str): +def guess_feature_info(t, name: str): """ - Return a dictionary with the 'dtype' and 'shape' for a given tensor or array. + Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value. If it looks like a 3D (C,H,W) shape, we might consider it an 'image'. - Otherwise default to 'float32' for numeric. You can customize as needed. + Otherwise default to appropriate dtype for numeric. """ + shape = tuple(t.shape) # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image' if len(shape) == 3 and shape[0] in [1, 3]: @@ -672,32 +882,33 @@ def concatenate_batch_transitions( left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition ) -> BatchTransition: """NOTE: Be careful it change the left_batch_transitions in place""" + # Concatenate state fields left_batch_transitions["state"] = { key: torch.cat( - [ - left_batch_transitions["state"][key], - right_batch_transition["state"][key], - ], + [left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0, ) for key in left_batch_transitions["state"] } + + # Concatenate basic fields left_batch_transitions["action"] = torch.cat( [left_batch_transitions["action"], right_batch_transition["action"]], dim=0 ) left_batch_transitions["reward"] = torch.cat( [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 ) + + # Concatenate next_state fields left_batch_transitions["next_state"] = { key: torch.cat( - [ - left_batch_transitions["next_state"][key], - right_batch_transition["next_state"][key], - ], + [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0, ) for key in left_batch_transitions["next_state"] } + + # Concatenate done and truncated fields left_batch_transitions["done"] = torch.cat( [left_batch_transitions["done"], right_batch_transition["done"]], dim=0 ) @@ -705,479 +916,114 @@ def concatenate_batch_transitions( [left_batch_transitions["truncated"], right_batch_transition["truncated"]], dim=0, ) + + # Handle complementary_info + left_info = left_batch_transitions.get("complementary_info") + right_info = right_batch_transition.get("complementary_info") + + # Only process if right_info exists + if right_info is not None: + # Initialize left complementary_info if needed + if left_info is None: + left_batch_transitions["complementary_info"] = right_info + else: + # Concatenate each field + for key in right_info: + if key in left_info: + left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0) + else: + left_info[key] = right_info[key] + return left_batch_transitions if __name__ == "__main__": - from tempfile import TemporaryDirectory - # ===== Test 1: Create and use a synthetic ReplayBuffer ===== - print("Testing synthetic ReplayBuffer...") + def test_load_dataset_with_complementary_info(): + """ + Test loading a dataset with complementary_info into a ReplayBuffer. + The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains + gripper_penalty values in complementary_info. + """ + import time - # Create sample data dimensions - batch_size = 32 - state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)} - action_dim = (6,) + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset - # Create a buffer - buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=list(state_dims.keys()), - use_drq=True, - storage_device="cpu", - ) - - # Add some random transitions - for i in range(100): - # Create dummy transition data - state = { - "observation.image": torch.rand(1, 3, 84, 84), - "observation.state": torch.rand(1, 10), - } - action = torch.rand(1, 6) - reward = 0.5 - next_state = { - "observation.image": torch.rand(1, 3, 84, 84), - "observation.state": torch.rand(1, 10), - } - done = False if i < 99 else True - truncated = False - - buffer.add( - state=state, - action=action, - reward=reward, - next_state=next_state, - done=done, - truncated=truncated, + print("Loading dataset with complementary info...") + # Load a small subset of the dataset (first episode) + dataset = LeRobotDataset( + repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty", ) - # Test sampling - batch = buffer.sample(batch_size) - print(f"Buffer size: {len(buffer)}") - print( - f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}" - ) - print(f"Sampled batch action shape: {batch['action'].shape}") - print(f"Sampled batch reward shape: {batch['reward'].shape}") - print(f"Sampled batch done shape: {batch['done'].shape}") - print(f"Sampled batch truncated shape: {batch['truncated'].shape}") - - # ===== Test for state-action-reward alignment ===== - print("\nTesting state-action-reward alignment...") - - # Create a buffer with controlled transitions where we know the relationships - aligned_buffer = ReplayBuffer( - capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu" - ) - - # Create transitions with known relationships - # - Each state has a unique signature value - # - Action is 2x the state signature - # - Reward is 3x the state signature - # - Next state is signature + 0.01 (unless at episode end) - for i in range(100): - # Create a state with a signature value that encodes the transition number - signature = float(i) / 100.0 - state = {"state_value": torch.tensor([[signature]]).float()} - - # Action is 2x the signature - action = torch.tensor([[2.0 * signature]]).float() - - # Reward is 3x the signature - reward = 3.0 * signature - - # Next state is signature + 0.01, unless end of episode - # End episode every 10 steps - is_end = (i + 1) % 10 == 0 - - if is_end: - # At episode boundaries, next_state repeats current state (as per your implementation) - next_state = {"state_value": torch.tensor([[signature]]).float()} - done = True - else: - # Within episodes, next_state has signature + 0.01 - next_signature = float(i + 1) / 100.0 - next_state = {"state_value": torch.tensor([[next_signature]]).float()} - done = False - - aligned_buffer.add(state, action, reward, next_state, done, False) - - # Sample from this buffer - aligned_batch = aligned_buffer.sample(50) - - # Verify alignments in sampled batch - correct_relationships = 0 - total_checks = 0 - - # For each transition in the batch - for i in range(50): - # Extract signature from state - state_sig = aligned_batch["state"]["state_value"][i].item() - - # Check action is 2x signature (within reasonable precision) - action_val = aligned_batch["action"][i].item() - action_check = abs(action_val - 2.0 * state_sig) < 1e-4 - - # Check reward is 3x signature (within reasonable precision) - reward_val = aligned_batch["reward"][i].item() - reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4 - - # Check next_state relationship matches our pattern - next_state_sig = aligned_batch["next_state"]["state_value"][i].item() - is_done = aligned_batch["done"][i].item() > 0.5 - - # Calculate expected next_state value based on done flag - if is_done: - # For episodes that end, next_state should equal state - next_state_check = abs(next_state_sig - state_sig) < 1e-4 - else: - # For continuing episodes, check if next_state is approximately state + 0.01 - # We need to be careful because we don't know the original index - # So we check if the increment is roughly 0.01 - next_state_check = ( - abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4 - ) - - # Count correct relationships - if action_check: - correct_relationships += 1 - if reward_check: - correct_relationships += 1 - if next_state_check: - correct_relationships += 1 - - total_checks += 3 - - alignment_accuracy = 100.0 * correct_relationships / total_checks - print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%") - if alignment_accuracy > 99.0: - print("✅ All relationships verified! Buffer maintains correct temporal relationships.") - else: - print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.") - - # Print some debug information about failures - print("\nDebug information for failed checks:") - for i in range(5): # Print first 5 transitions for debugging - state_sig = aligned_batch["state"]["state_value"][i].item() - action_val = aligned_batch["action"][i].item() - reward_val = aligned_batch["reward"][i].item() - next_state_sig = aligned_batch["next_state"]["state_value"][i].item() - is_done = aligned_batch["done"][i].item() > 0.5 - - print(f"Transition {i}:") - print(f" State: {state_sig:.6f}") - print(f" Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})") - print(f" Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})") - print(f" Done: {is_done}") - print(f" Next state: {next_state_sig:.6f}") - - # Calculate expected next state - if is_done: - expected_next = state_sig - else: - # This approximation might not be perfect - state_idx = round(state_sig * 100) - expected_next = (state_idx + 1) / 100.0 - - print(f" Expected next state: {expected_next:.6f}") - print() - - # ===== Test 2: Convert to LeRobotDataset and back ===== - with TemporaryDirectory() as temp_dir: - print("\nTesting conversion to LeRobotDataset and back...") - # Convert buffer to dataset - repo_id = "test/replay_buffer_conversion" - # Create a subdirectory to avoid the "directory exists" error - dataset_dir = os.path.join(temp_dir, "dataset1") - dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir) - - print(f"Dataset created with {len(dataset)} frames") + print(f"Dataset loaded with {len(dataset)} frames") print(f"Dataset features: {list(dataset.features.keys())}") - # Check a random sample from the dataset + # Check if dataset has complementary_info.gripper_penalty sample = dataset[0] - print( - f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}" + complementary_info_keys = [key for key in sample if key.startswith("complementary_info")] + print(f"Complementary info keys: {complementary_info_keys}") + + if "complementary_info.gripper_penalty" in sample: + print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}") + + # Extract state keys for the buffer + state_keys = [] + for key in sample: + if key.startswith("observation"): + state_keys.append(key) + + print(f"Using state keys: {state_keys}") + + # Create a replay buffer from the dataset + start_time = time.time() + buffer = ReplayBuffer.from_lerobot_dataset( + lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False ) + load_time = time.time() - start_time + print(f"Loaded dataset into buffer in {load_time:.2f} seconds") + print(f"Buffer size: {len(buffer)}") - # Convert dataset back to buffer - reconverted_buffer = ReplayBuffer.from_lerobot_dataset( - dataset, state_keys=list(state_dims.keys()), device="cpu" - ) + # Check if complementary_info was transferred correctly + print("Sampling from buffer to check complementary_info...") + batch = buffer.sample(batch_size=4) - print(f"Reconverted buffer size: {len(reconverted_buffer)}") - - # Sample from the reconverted buffer - reconverted_batch = reconverted_buffer.sample(batch_size) - print( - f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}" - ) - - # Verify consistency before and after conversion - original_states = batch["state"]["observation.image"].mean().item() - reconverted_states = reconverted_batch["state"]["observation.image"].mean().item() - print(f"Original buffer state mean: {original_states:.4f}") - print(f"Reconverted buffer state mean: {reconverted_states:.4f}") - - if abs(original_states - reconverted_states) < 1.0: - print("Values are reasonably similar - conversion works as expected") + if batch["complementary_info"] is not None: + print("Complementary info in batch:") + for key, value in batch["complementary_info"].items(): + print(f" {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}") + if key == "gripper_penalty": + print(f" Sample gripper_penalty values: {value[:5]}") else: - print("WARNING: Significant difference between original and reconverted values") + print("No complementary_info found in batch") - print("\nAll previous tests completed!") - - # ===== Test for memory optimization ===== - print("\n===== Testing Memory Optimization =====") - - # Create two buffers, one with memory optimization and one without - standard_buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=["observation.image", "observation.state"], - storage_device="cpu", - optimize_memory=False, - use_drq=True, - ) - - optimized_buffer = ReplayBuffer( - capacity=1000, - device="cpu", - state_keys=["observation.image", "observation.state"], - storage_device="cpu", - optimize_memory=True, - use_drq=True, - ) - - # Generate sample data with larger state dimensions for better memory impact - print("Generating test data...") - num_episodes = 10 - steps_per_episode = 50 - total_steps = num_episodes * steps_per_episode - - for episode in range(num_episodes): - for step in range(steps_per_episode): - # Index in the overall sequence - i = episode * steps_per_episode + step - - # Create state with identifiable values - img = torch.ones((3, 84, 84)) * (i / total_steps) - state_vec = torch.ones((10,)) * (i / total_steps) - - state = { - "observation.image": img.unsqueeze(0), - "observation.state": state_vec.unsqueeze(0), - } - - # Create next state (i+1 or same as current if last in episode) - is_last_step = step == steps_per_episode - 1 - - if is_last_step: - # At episode end, next state = current state - next_img = img.clone() - next_state_vec = state_vec.clone() - done = True - truncated = False - else: - # Within episode, next state has incremented value - next_val = (i + 1) / total_steps - next_img = torch.ones((3, 84, 84)) * next_val - next_state_vec = torch.ones((10,)) * next_val - done = False - truncated = False - - next_state = { - "observation.image": next_img.unsqueeze(0), - "observation.state": next_state_vec.unsqueeze(0), - } - - # Action and reward - action = torch.tensor([[i / total_steps]]) - reward = float(i / total_steps) - - # Add to both buffers - standard_buffer.add(state, action, reward, next_state, done, truncated) - optimized_buffer.add(state, action, reward, next_state, done, truncated) - - # Verify episode boundaries with our simplified approach - print("\nVerifying simplified memory optimization...") - - # Test with a new buffer with a small sequence - test_buffer = ReplayBuffer( - capacity=20, - device="cpu", - state_keys=["value"], - storage_device="cpu", - optimize_memory=True, - use_drq=False, - ) - - # Add a simple sequence with known episode boundaries - for i in range(20): - val = float(i) - state = {"value": torch.tensor([[val]]).float()} - next_val = float(i + 1) if i % 5 != 4 else val # Episode ends every 5 steps - next_state = {"value": torch.tensor([[next_val]]).float()} - - # Set done=True at every 5th step - done = (i % 5) == 4 - action = torch.tensor([[0.0]]) - reward = 1.0 - truncated = False - - test_buffer.add(state, action, reward, next_state, done, truncated) - - # Get sequential batch for verification - sequential_batch_size = test_buffer.size - all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device) - - # Get state tensors - batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)} - - # Get next_state using memory-optimized approach (simply index+1) - next_indices = (all_indices + 1) % test_buffer.capacity - batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)} - - # Get other tensors - batch_dones = test_buffer.dones[all_indices].to(test_buffer.device) - - # Print sequential values - print("State, Next State, Done (Sequential values with simplified optimization):") - state_values = batch_state["value"].squeeze().tolist() - next_values = batch_next_state["value"].squeeze().tolist() - done_flags = batch_dones.tolist() - - # Print all values - for i in range(len(state_values)): - print(f" {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}") - - # Explain the memory optimization tradeoff - print("\nWith simplified memory optimization:") - print("- We always use the next state in the buffer (index+1) as next_state") - print("- For terminal states, this means using the first state of the next episode") - print("- This is a common tradeoff in RL implementations for memory efficiency") - print("- Since we track done flags, the algorithm can handle these transitions correctly") - - # Test random sampling - print("\nVerifying random sampling with simplified memory optimization...") - random_samples = test_buffer.sample(20) # Sample all transitions - - # Extract values - random_state_values = random_samples["state"]["value"].squeeze().tolist() - random_next_values = random_samples["next_state"]["value"].squeeze().tolist() - random_done_flags = random_samples["done"].bool().tolist() - - # Print a few samples - print("Random samples - State, Next State, Done (First 10):") - for i in range(10): - print(f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}") - - # Calculate memory savings - # Assume optimized_buffer and standard_buffer have already been initialized and filled - std_mem = ( - sum( - standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size() - for key in standard_buffer.states + # Now convert the buffer back to a LeRobotDataset + print("\nConverting buffer back to LeRobotDataset...") + start_time = time.time() + new_dataset = buffer.to_lerobot_dataset( + repo_id="test_dataset_from_buffer", + fps=dataset.fps, + root="./test_dataset_from_buffer", + task_name="test_conversion", ) - * 2 - ) - opt_mem = sum( - optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size() - for key in optimized_buffer.states - ) + convert_time = time.time() - start_time + print(f"Converted buffer to dataset in {convert_time:.2f} seconds") + print(f"New dataset size: {len(new_dataset)} frames") - savings_percent = (std_mem - opt_mem) / std_mem * 100 + # Check if complementary_info was preserved + new_sample = new_dataset[0] + new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")] + print(f"New dataset complementary info keys: {new_complementary_info_keys}") - print("\nMemory optimization result:") - print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB") - print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB") - print(f"- Memory savings for state tensors: {savings_percent:.1f}%") + if "complementary_info.gripper_penalty" in new_sample: + print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}") - print("\nAll memory optimization tests completed!") + # Compare original and new datasets + print("\nComparing original and new datasets:") + print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}") + print(f"Original features: {list(dataset.features.keys())}") + print(f"New features: {list(new_dataset.features.keys())}") - # # ===== Test real dataset conversion ===== - # print("\n===== Testing Real LeRobotDataset Conversion =====") - # try: - # # Try to use a real dataset if available - # dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small" - # dataset = LeRobotDataset(repo_id=dataset_name) + return buffer, dataset, new_dataset - # # Print available keys to debug - # sample = dataset[0] - # print("Available keys in dataset:", list(sample.keys())) - - # # Check for required keys - # if "action" not in sample or "next.reward" not in sample: - # print("Dataset missing essential keys. Cannot convert.") - # raise ValueError("Missing required keys in dataset") - - # # Auto-detect appropriate state keys - # image_keys = [] - # state_keys = [] - # for k, v in sample.items(): - # # Skip metadata keys and action/reward keys - # if k in { - # "index", - # "episode_index", - # "frame_index", - # "timestamp", - # "task_index", - # "action", - # "next.reward", - # "next.done", - # }: - # continue - - # # Infer key type from tensor shape - # if isinstance(v, torch.Tensor): - # if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1): - # # Likely an image (channels, height, width) - # image_keys.append(k) - # else: - # # Likely state or other vector - # state_keys.append(k) - - # print(f"Detected image keys: {image_keys}") - # print(f"Detected state keys: {state_keys}") - - # if not image_keys and not state_keys: - # print("No usable keys found in dataset, skipping further tests") - # raise ValueError("No usable keys found in dataset") - - # # Test with standard and memory-optimized buffers - # for optimize_memory in [False, True]: - # buffer_type = "Standard" if not optimize_memory else "Memory-optimized" - # print(f"\nTesting {buffer_type} buffer with real dataset...") - - # # Convert to ReplayBuffer with detected keys - # replay_buffer = ReplayBuffer.from_lerobot_dataset( - # lerobot_dataset=dataset, - # state_keys=image_keys + state_keys, - # device="cpu", - # optimize_memory=optimize_memory, - # ) - # print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}") - - # # Test sampling - # real_batch = replay_buffer.sample(32) - # print(f"Sampled batch from real dataset ({buffer_type}), state shapes:") - # for key in real_batch["state"]: - # print(f" {key}: {real_batch['state'][key].shape}") - - # # Convert back to LeRobotDataset - # with TemporaryDirectory() as temp_dir: - # dataset_name = f"test/real_dataset_converted_{buffer_type}" - # replay_buffer_converted = replay_buffer.to_lerobot_dataset( - # repo_id=dataset_name, - # root=os.path.join(temp_dir, f"dataset_{buffer_type}"), - # ) - # print( - # f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames" - # ) - - # except Exception as e: - # print(f"Real dataset test failed: {e}") - # print("This is expected if running offline or if the dataset is not available.") - - # print("\nAll tests completed!") + # Run the test + test_load_dataset_with_complementary_info() diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 92e8dcbc..3aa75466 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper): return observation +class GripperPenaltyWrapper(gym.RewardWrapper): + def __init__(self, env, penalty: float = -0.1): + super().__init__(env) + self.penalty = penalty + self.last_gripper_state = None + + def reward(self, reward, action): + gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND + + if isinstance(action, tuple): + action = action[0] + action_normalized = action[-1] / MAX_GRIPPER_COMMAND + + gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or ( + gripper_state_normalized > 0.9 and action_normalized < 0.1 + ) + breakpoint() + + return reward + self.penalty * gripper_penalty_bool + + def step(self, action): + self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + obs, reward, terminated, truncated, info = self.env.step(action) + reward = self.reward(reward, action) + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + self.last_gripper_state = None + return super().reset(**kwargs) + + +class GripperQuantizationWrapper(gym.ActionWrapper): + def __init__(self, env, quantization_threshold: float = 0.2): + super().__init__(env) + self.quantization_threshold = quantization_threshold + + def action(self, action): + is_intervention = False + if isinstance(action, tuple): + action, is_intervention = action + + gripper_command = action[-1] + # Quantize gripper command to -1, 0 or 1 + if gripper_command < -self.quantization_threshold: + gripper_command = -MAX_GRIPPER_COMMAND + elif gripper_command > self.quantization_threshold: + gripper_command = MAX_GRIPPER_COMMAND + else: + gripper_command = 0.0 + + gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] + gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) + action[-1] = gripper_action.item() + return action, is_intervention + + class EEActionWrapper(gym.ActionWrapper): def __init__(self, env, ee_action_space_params=None, use_gripper=False): super().__init__(env) @@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper): fk_func=self.fk_function, ) if self.use_gripper: - # Quantize gripper command to -1, 0 or 1 - if gripper_command < -0.2: - gripper_command = -1.0 - elif gripper_command > 0.2: - gripper_command = 1.0 - else: - gripper_command = 0.0 - - gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] - gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) - target_joint_pos[-1] = gripper_action + target_joint_pos[-1] = gripper_command return target_joint_pos, is_intervention @@ -1118,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv: # Add reward computation and control wrappers # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) + if cfg.wrapper.use_gripper: + env = GripperQuantizationWrapper( + env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold + ) + # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty) + if cfg.wrapper.ee_action_space_params is not None: env = EEActionWrapper( env=env, diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 98d2dbd8..5489d6dc 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +# !/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. # All rights reserved. @@ -269,6 +269,7 @@ def add_actor_information_and_train( policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency saving_checkpoint = cfg.save_checkpoint online_steps = cfg.policy.online_steps + async_prefetch = cfg.policy.async_prefetch # Initialize logging for multiprocessing if not use_threads(cfg): @@ -326,6 +327,9 @@ def add_actor_information_and_train( if cfg.dataset is not None: dataset_repo_id = cfg.dataset.repo_id + # Initialize iterators + online_iterator = None + offline_iterator = None # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER while True: # Exit the training loop if shutdown is requested @@ -359,13 +363,26 @@ def add_actor_information_and_train( if len(replay_buffer) < online_step_before_learning: continue + if online_iterator is None: + logging.debug("[LEARNER] Initializing online replay buffer iterator") + online_iterator = replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + if offline_replay_buffer is not None and offline_iterator is None: + logging.debug("[LEARNER] Initializing offline replay buffer iterator") + offline_iterator = offline_replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + logging.debug("[LEARNER] Starting optimization loop") time_for_one_optimization_step = time.time() for _ in range(utd_ratio - 1): - batch = replay_buffer.sample(batch_size=batch_size) + # Sample from the iterators + batch = next(online_iterator) if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size=batch_size) + batch_offline = next(offline_iterator) batch = concatenate_batch_transitions( left_batch_transitions=batch, right_batch_transition=batch_offline ) @@ -392,24 +409,37 @@ def add_actor_information_and_train( "next_observation_feature": next_observation_features, } - # Use the forward method for critic loss - loss_critic = policy.forward(forward_batch, model="critic") + # Use the forward method for critic loss (includes both main critic and grasp critic) + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() loss_critic.backward() - - # clip gradients critic_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value ) - optimizers["critic"].step() + # Grasp critic optimization (if available) + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="grasp_critic") + loss_grasp_critic = discrete_critic_output["loss_grasp_critic"] + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value + ) + optimizers["grasp_critic"].step() + + # Update target networks policy.update_target_networks() - batch = replay_buffer.sample(batch_size=batch_size) + # Sample for the last update in the UTD ratio + batch = next(online_iterator) if dataset_repo_id is not None: - batch_offline = offline_replay_buffer.sample(batch_size=batch_size) + batch_offline = next(offline_iterator) batch = concatenate_batch_transitions( left_batch_transitions=batch, right_batch_transition=batch_offline ) @@ -437,63 +467,80 @@ def add_actor_information_and_train( "next_observation_feature": next_observation_features, } - # Use the forward method for critic loss - loss_critic = policy.forward(forward_batch, model="critic") + # Use the forward method for critic loss (includes both main critic and grasp critic) + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] optimizers["critic"].zero_grad() loss_critic.backward() - - # clip gradients critic_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value ).item() - optimizers["critic"].step() - training_infos = {} - training_infos["loss_critic"] = loss_critic.item() - training_infos["critic_grad_norm"] = critic_grad_norm + # Initialize training info dictionary + training_infos = { + "loss_critic": loss_critic.item(), + "critic_grad_norm": critic_grad_norm, + } + # Grasp critic optimization (if available) + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="grasp_critic") + loss_grasp_critic = discrete_critic_output["loss_grasp_critic"] + optimizers["grasp_critic"].zero_grad() + loss_grasp_critic.backward() + grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value + ).item() + optimizers["grasp_critic"].step() + + # Add grasp critic info to training info + training_infos["loss_grasp_critic"] = loss_grasp_critic.item() + training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm + + # Actor and temperature optimization (at specified frequency) if optimization_step % policy_update_freq == 0: for _ in range(policy_update_freq): - # Use the forward method for actor loss - loss_actor = policy.forward(forward_batch, model="actor") - + # Actor optimization + actor_output = policy.forward(forward_batch, model="actor") + loss_actor = actor_output["loss_actor"] optimizers["actor"].zero_grad() loss_actor.backward() - - # clip gradients actor_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value ).item() - optimizers["actor"].step() + # Add actor info to training info training_infos["loss_actor"] = loss_actor.item() training_infos["actor_grad_norm"] = actor_grad_norm - # Temperature optimization using forward method - loss_temperature = policy.forward(forward_batch, model="temperature") + # Temperature optimization + temperature_output = policy.forward(forward_batch, model="temperature") + loss_temperature = temperature_output["loss_temperature"] optimizers["temperature"].zero_grad() loss_temperature.backward() - - # clip gradients temp_grad_norm = torch.nn.utils.clip_grad_norm_( parameters=[policy.log_alpha], max_norm=clip_grad_norm_value ).item() - optimizers["temperature"].step() + # Add temperature info to training info training_infos["loss_temperature"] = loss_temperature.item() training_infos["temperature_grad_norm"] = temp_grad_norm training_infos["temperature"] = policy.temperature + # Update temperature policy.update_temperature() - # Check if it's time to push updated policy to actors + # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) last_time_policy_pushed = time.time() + # Update target networks policy.update_target_networks() # Log training metrics at specified intervals @@ -697,7 +744,7 @@ def save_training_checkpoint( logging.info("Resume training") -def make_optimizers_and_scheduler(cfg, policy: nn.Module): +def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module): """ Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. @@ -728,7 +775,14 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): params=policy.actor.parameters_to_optimize, lr=cfg.policy.actor_lr, ) - optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + optimizer_critic = torch.optim.Adam( + params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr + ) + + if cfg.policy.num_discrete_actions is not None: + optimizer_grasp_critic = torch.optim.Adam( + params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr + ) optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) lr_scheduler = None optimizers = { @@ -736,6 +790,8 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): "critic": optimizer_critic, "temperature": optimizer_temperature, } + if cfg.policy.num_discrete_actions is not None: + optimizers["grasp_critic"] = optimizer_grasp_critic return optimizers, lr_scheduler @@ -970,12 +1026,8 @@ def get_observation_features( return None, None with torch.no_grad(): - observation_features = ( - policy.actor.encoder(observations) if policy.actor.encoder is not None else None - ) - next_observation_features = ( - policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None - ) + observation_features = policy.actor.encoder.get_image_features(observations) + next_observation_features = policy.actor.encoder.get_image_features(next_observations) return observation_features, next_observation_features diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index e10b8766..03a7ec10 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -1,5 +1,3 @@ -import logging -import time from typing import Any import einops @@ -10,7 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from lerobot.common.envs.configs import ManiskillEnvConfig -from lerobot.configs import parser def preprocess_maniskill_observation( @@ -153,6 +150,27 @@ class TimeLimitWrapper(gym.Wrapper): return super().reset(seed=seed, options=options) +class ManiskillMockGripperWrapper(gym.Wrapper): + def __init__(self, env, nb_discrete_actions: int = 3): + super().__init__(env) + new_shape = env.action_space[0].shape[0] + 1 + new_low = np.concatenate([env.action_space[0].low, [0]]) + new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]]) + action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,)) + self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1])) + + def step(self, action): + if isinstance(action, tuple): + action_agent, telop_action = action + else: + telop_action = 0 + action_agent = action + real_action = action_agent[:-1] + final_action = (real_action, telop_action) + obs, reward, terminated, truncated, info = self.env.step(final_action) + return obs, reward, terminated, truncated, info + + def make_maniskill( cfg: ManiskillEnvConfig, n_envs: int | None = None, @@ -197,40 +215,42 @@ def make_maniskill( env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control + if cfg.mock_gripper: + env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3) return env -@parser.wrap() -def main(cfg: ManiskillEnvConfig): - """Main function to run the ManiSkill environment.""" - # Create the ManiSkill environment - env = make_maniskill(cfg, n_envs=1) +# @parser.wrap() +# def main(cfg: TrainPipelineConfig): +# """Main function to run the ManiSkill environment.""" +# # Create the ManiSkill environment +# env = make_maniskill(cfg.env, n_envs=1) - # Reset the environment - obs, info = env.reset() +# # Reset the environment +# obs, info = env.reset() - # Run a simple interaction loop - sum_reward = 0 - for i in range(100): - # Sample a random action - action = env.action_space.sample() +# # Run a simple interaction loop +# sum_reward = 0 +# for i in range(100): +# # Sample a random action +# action = env.action_space.sample() - # Step the environment - start_time = time.perf_counter() - obs, reward, terminated, truncated, info = env.step(action) - step_time = time.perf_counter() - start_time - sum_reward += reward - # Log information +# # Step the environment +# start_time = time.perf_counter() +# obs, reward, terminated, truncated, info = env.step(action) +# step_time = time.perf_counter() - start_time +# sum_reward += reward +# # Log information - # Reset if episode terminated - if terminated or truncated: - logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") - sum_reward = 0 - obs, info = env.reset() +# # Reset if episode terminated +# if terminated or truncated: +# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s") +# sum_reward = 0 +# obs, info = env.reset() - # Close the environment - env.close() +# # Close the environment +# env.close() # if __name__ == "__main__":