From 4a1c26d9ee7108d86e60feee82b3d20ced026642 Mon Sep 17 00:00:00 2001
From: s1lent4gnt <kmeftah.khalil@gmail.com>
Date: Mon, 31 Mar 2025 17:35:59 +0200
Subject: [PATCH 01/28] Add grasp critic

- Implemented grasp critic to evaluate gripper actions
- Added corresponding config parameters for tuning
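
At its core, the grasp critic is a small Q-network over a discrete set of
gripper actions, and the executed gripper action is the argmax over its
Q-values. A minimal, self-contained sketch of that idea (illustrative only;
the class and dimensions below are placeholders, not the code added in this
patch):

    import torch
    import torch.nn as nn

    class TinyGraspCritic(nn.Module):
        """Toy Q-network over 3 discrete gripper actions: open / stay / close."""

        def __init__(self, feature_dim: int = 64, hidden_dim: int = 256, num_actions: int = 3):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(feature_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, num_actions),
            )

        def forward(self, features: torch.Tensor) -> torch.Tensor:
            return self.net(features)  # (batch, num_actions) Q-values

    q_values = TinyGraspCritic()(torch.randn(8, 64))
    gripper_action = torch.argmax(q_values, dim=-1) - 1  # map {0, 1, 2} -> {-1, 0, +1}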
---
 .../common/policies/sac/configuration_sac.py  |   8 ++
 lerobot/common/policies/sac/modeling_sac.py   | 124 ++++++++++++++++++
 2 files changed, 132 insertions(+)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 906a3bed..e47185da 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -42,6 +42,14 @@ class CriticNetworkConfig:
     final_activation: str | None = None
 
 
+@dataclass
+class GraspCriticNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+    final_activation: str | None = None
+    output_dim: int = 3
+
+
 @dataclass
 class ActorNetworkConfig:
     hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index f7866714..3589ad25 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -112,6 +112,26 @@ class SACPolicy(
 
         self.critic_ensemble = torch.compile(self.critic_ensemble)
         self.critic_target = torch.compile(self.critic_target)
+
+        # Create grasp critic
+        self.grasp_critic = GraspCritic(
+            encoder=encoder_critic,
+            input_dim=encoder_critic.output_dim,
+            **config.grasp_critic_network_kwargs,
+        )
+
+        # Create target grasp critic
+        self.grasp_critic_target = GraspCritic(
+            encoder=encoder_critic,
+            input_dim=encoder_critic.output_dim,
+            **config.grasp_critic_network_kwargs,
+        )
+
+        self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+
+        self.grasp_critic = torch.compile(self.grasp_critic)
+        self.grasp_critic_target = torch.compile(self.grasp_critic_target)
+
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
@@ -176,6 +196,21 @@ class SACPolicy(
         q_values = critics(observations, actions, observation_features)
         return q_values
 
+    def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
+        """Forward pass through a grasp critic network
+        
+        Args:
+            observations: Dictionary of observations
+            use_target: If True, use the target grasp critic, otherwise use the online grasp critic
+            observation_features: Optional pre-computed observation features to avoid recomputing encoder output
+
+        Returns:
+            Tensor of Q-values from the grasp critic network
+        """
+        grasp_critic = self.grasp_critic_target if use_target else self.grasp_critic
+        q_values = grasp_critic(observations, observation_features)
+        return q_values
+
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
@@ -246,6 +281,18 @@ class SACPolicy(
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
 
+    def update_grasp_target_networks(self):
+        """Update grasp target networks with exponential moving average"""
+        for target_param, param in zip(
+            self.grasp_critic_target.parameters(),
+            self.grasp_critic.parameters(),
+            strict=False,
+        ):
+            target_param.data.copy_(
+                param.data * self.config.critic_target_update_weight
+                + target_param.data * (1.0 - self.config.critic_target_update_weight)
+            )
+    
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
 
@@ -307,6 +354,32 @@ class SACPolicy(
         ).sum()
         return critics_loss
 
+    def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
+
+        batch_size = rewards.shape[0]
+        grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
+
+        with torch.no_grad():
+            next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
+            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+
+            target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True)
+            target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action]
+
+        # Get the grasp penalty from complementary_info
+        grasp_penalty = torch.zeros_like(rewards)
+        if complementary_info is not None and "grasp_penalty" in complementary_info:
+            grasp_penalty = complementary_info["grasp_penalty"]
+
+        grasp_rewards = rewards + grasp_penalty
+        target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q
+
+        predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False)
+        predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions]
+
+        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean")
+        return grasp_critic_loss
+
     def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
         """Compute the temperature loss"""
         # calculate temperature loss
@@ -509,6 +582,57 @@ class CriticEnsemble(nn.Module):
         return q_values
 
 
+class GraspCritic(nn.Module):
+    def __init__(
+        self,
+        encoder: Optional[nn.Module],
+        network: nn.Module,
+        output_dim: int = 3,
+        init_final: Optional[float] = None,
+        encoder_is_shared: bool = False,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.network = network
+        self.output_dim = output_dim
+
+        # Find the last Linear layer's output dimension
+        for layer in reversed(network.net):
+            if isinstance(layer, nn.Linear):
+                out_features = layer.out_features
+                break
+
+        self.parameters_to_optimize = []
+        self.parameters_to_optimize += list(self.network.parameters())
+
+        if self.encoder is not None and not encoder_is_shared:
+            self.parameters_to_optimize += list(self.encoder.parameters())
+
+        self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
+        if init_final is not None:
+            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
+            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
+        else:
+            orthogonal_init()(self.output_layer.weight)
+
+        self.parameters_to_optimize += list(self.output_layer.parameters())
+
+    def forward(
+        self,
+        observations: torch.Tensor,
+        observation_features: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        device = get_device_from_parameters(self)
+        # Move each tensor in observations to device
+        observations = {k: v.to(device) for k, v in observations.items()}
+        # Encode observations if encoder exists
+        obs_enc = (
+            observation_features
+            if observation_features is not None
+            else (observations if self.encoder is None else self.encoder(observations))
+        )
+        return self.output_layer(self.network(obs_enc))
+
+
 class Policy(nn.Module):
     def __init__(
         self,

From 007fee923071a860ff05bf1d7b536375ed6dea5f Mon Sep 17 00:00:00 2001
From: s1lent4gnt <kmeftah.khalil@gmail.com>
Date: Mon, 31 Mar 2025 17:36:35 +0200
Subject: [PATCH 02/28] Add complementary info in the replay buffer

- Added complementary info handling to the add method
- Added complementary info handling to the sample method
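
Conceptually, the storage is a lazily allocated per-key tensor buffer sized to
the replay capacity. A simplified sketch of that pattern (an assumed helper,
not the ReplayBuffer API itself):

    import torch

    def store_info(storage: dict[str, torch.Tensor], info: dict, position: int, capacity: int) -> None:
        """Allocate a (capacity, *shape) buffer the first time a key is seen, then write at position."""
        for key, value in info.items():
            value = torch.as_tensor(value)
            if key not in storage:
                shape = value.shape if value.ndim > 0 else (1,)
                storage[key] = torch.zeros((capacity, *shape), dtype=value.dtype)
            storage[key][position] = value

    buffer: dict[str, torch.Tensor] = {}
    store_info(buffer, {"grasp_penalty": -0.05}, position=0, capacity=1000)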
---
 lerobot/scripts/server/buffer.py | 37 ++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 776ad9ec..1fbc8803 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -244,6 +244,11 @@ class ReplayBuffer:
 
         self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
+
+        # Initialize complementary_info storage
+        self.complementary_info_keys = []
+        self.complementary_info_storage = {}
+        
         self.initialized = True
 
     def __len__(self):
@@ -277,6 +282,30 @@ class ReplayBuffer:
         self.dones[self.position] = done
         self.truncateds[self.position] = truncated
 
+        # Store complementary info if provided
+        if complementary_info is not None:
+            # Initialize storage for new keys on first encounter
+            for key, value in complementary_info.items():
+                if key not in self.complementary_info_keys:
+                    self.complementary_info_keys.append(key)
+                    if isinstance(value, torch.Tensor):
+                        shape = value.shape if value.ndim > 0 else (1,)
+                        self.complementary_info_storage[key] = torch.zeros(
+                            (self.capacity, *shape), 
+                            dtype=value.dtype, 
+                            device=self.storage_device
+                        )
+                        
+                # Store the value
+                if key in self.complementary_info_storage:
+                    if isinstance(value, torch.Tensor):
+                        self.complementary_info_storage[key][self.position] = value
+                    else:
+                        # For non-tensor values (like grasp_penalty)
+                        self.complementary_info_storage[key][self.position] = torch.tensor(
+                            value, device=self.storage_device
+                    )
+
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
 
@@ -335,6 +364,13 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()
 
+        # Add complementary_info to batch if it exists
+        batch_complementary_info = {}
+        if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
+            for key in self.complementary_info_keys:
+                if key in self.complementary_info_storage:
+                    batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
+
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
@@ -342,6 +378,7 @@ class ReplayBuffer:
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
+            complementary_info=batch_complementary_info if batch_complementary_info else None,
         )
 
     @classmethod

From 7452f9baaa57d94b1b738ad94acdb23abe57c6d7 Mon Sep 17 00:00:00 2001
From: s1lent4gnt <kmeftah.khalil@gmail.com>
Date: Mon, 31 Mar 2025 17:38:16 +0200
Subject: [PATCH 03/28] Add gripper penalty wrapper

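The wrapper reports a fixed penalty in `info["grasp_penalty"]` whenever a strong
open/close command disagrees with the last observed gripper position relative to
the thresholds, and 0.0 otherwise. A stripped-down sketch of that check (the
thresholds mirror the wrapper below; the function name is illustrative):

    def grasp_penalty(gripper_cmd: float, last_gripper_pos: float, penalty: float = -0.05) -> float:
        mismatched = (gripper_cmd < -0.5 and last_gripper_pos > 0.9) or (
            gripper_cmd > 0.5 and last_gripper_pos < 0.9
        )
        return penalty if mismatched else 0.0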
---
 lerobot/scripts/server/gym_manipulator.py | 24 +++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 92e8dcbc..26ed1991 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -1069,6 +1069,29 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
+class GripperPenaltyWrapper(gym.Wrapper):
+    def __init__(self, env, penalty=-0.05):
+        super().__init__(env)
+        self.penalty = penalty
+        self.last_gripper_pos = None
+
+    def reset(self, **kwargs):
+        obs, info = self.env.reset(**kwargs)
+        self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper
+        return obs, info
+
+    def step(self, action):
+        observation, reward, terminated, truncated, info = self.env.step(action)
+
+        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9):
+            info["grasp_penalty"] = self.penalty
+        else:
+            info["grasp_penalty"] = 0.0
+
+        self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper
+        return observation, reward, terminated, truncated, info
+
+
 def make_robot_env(cfg) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.
@@ -1144,6 +1167,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
+    env= GripperPenaltyWrapper(env=env)
 
     return env
 

From 2c1e5fa28b67e05e932ebba2e83ec75c73db3c34 Mon Sep 17 00:00:00 2001
From: s1lent4gnt <kmeftah.khalil@gmail.com>
Date: Mon, 31 Mar 2025 17:40:00 +0200
Subject: [PATCH 04/28] Add get_gripper_action method to GamepadController

---
 .../server/end_effector_control_utils.py      | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py
index 3bd927b4..f272426d 100644
--- a/lerobot/scripts/server/end_effector_control_utils.py
+++ b/lerobot/scripts/server/end_effector_control_utils.py
@@ -311,6 +311,31 @@ class GamepadController(InputController):
         except pygame.error:
             logging.error("Error reading gamepad. Is it still connected?")
             return 0.0, 0.0, 0.0
+    
+    def get_gripper_action(self):
+        """
+        Get gripper action using L3/R3 buttons.
+        Press left stick (L3) to open the gripper.
+        Press right stick (R3) to close the gripper.
+        """
+        import pygame
+        
+        try:
+            # Check if buttons are pressed
+            l3_pressed = self.joystick.get_button(9)
+            r3_pressed = self.joystick.get_button(10)
+            
+            # Determine action based on button presses
+            if r3_pressed:
+                return 1.0  # Close gripper
+            elif l3_pressed:
+                return -1.0  # Open gripper
+            else:
+                return 0.0  # No change
+                
+        except pygame.error:
+            logging.error(f"Error reading gamepad. Is it still connected?")
+            return 0.0
 
 
 class GamepadControllerHID(InputController):

From c774bbe5222e376dd142b491fe9c206902a83a99 Mon Sep 17 00:00:00 2001
From: s1lent4gnt <kmeftah.khalil@gmail.com>
Date: Mon, 31 Mar 2025 18:06:21 +0200
Subject: [PATCH 05/28] Add grasp critic to the training loop

- Integrated the grasp critic gradient update into the training loop in learner_server
- Added an Adam optimizer and a grasp critic learning rate setting in configuration_sac
- Added the grasp critic target network update after the critic gradient step
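
The per-step update follows the usual DQN-style recipe: compute the loss, clip
gradients, step the optimizer, then Polyak-average the target network. A hedged
sketch with placeholder arguments (`policy`, `optimizer`, and `tau` stand in for
the objects wired up in this patch):

    import torch

    def grasp_critic_step(policy, batch, optimizer, max_grad_norm: float, tau: float) -> float:
        loss = policy.forward(batch, model="grasp_critic")
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.grasp_critic.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        # Soft update: target <- tau * online + (1 - tau) * target
        with torch.no_grad():
            for tp, p in zip(policy.grasp_critic_target.parameters(), policy.grasp_critic.parameters()):
                tp.data.mul_(1.0 - tau).add_(p.data, alpha=tau)
        return loss.item()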
---
 .../common/policies/sac/configuration_sac.py  |  1 +
 lerobot/common/policies/sac/modeling_sac.py   | 19 ++++++++--
 lerobot/scripts/server/learner_server.py      | 35 +++++++++++++++++++
 3 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index e47185da..b1ce30b6 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -167,6 +167,7 @@ class SACConfig(PreTrainedConfig):
     num_critics: int = 2
     num_subsample_critics: int | None = None
     critic_lr: float = 3e-4
+    grasp_critic_lr: float = 3e-4
     actor_lr: float = 3e-4
     temperature_lr: float = 3e-4
     critic_target_update_weight: float = 0.005
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 3589ad25..bd74c65b 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -214,7 +214,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
 
@@ -227,7 +227,7 @@ class SACPolicy(
                 - done: Done mask tensor
                 - observation_feature: Optional pre-computed observation features
                 - next_observation_feature: Optional pre-computed next observation features
-            model: Which model to compute the loss for ("actor", "critic", or "temperature")
+            model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
 
         Returns:
             The computed loss tensor
@@ -254,6 +254,21 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+        
+        if model == "grasp_critic":
+            # Extract grasp_critic-specific components
+            complementary_info: dict[str, Tensor] = batch["complementary_info"]
+
+            return self.compute_loss_grasp_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
+            )
 
         if model == "actor":
             return self.compute_loss_actor(
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 98d2dbd8..f79e8d57 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -375,6 +375,7 @@ def add_actor_information_and_train(
             observations = batch["state"]
             next_observations = batch["next_state"]
             done = batch["done"]
+            complementary_info = batch["complementary_info"]
             check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
             observation_features, next_observation_features = get_observation_features(
@@ -390,6 +391,7 @@ def add_actor_information_and_train(
                 "done": done,
                 "observation_feature": observation_features,
                 "next_observation_feature": next_observation_features,
+                "complementary_info": complementary_info,
             }
 
             # Use the forward method for critic loss
@@ -404,7 +406,20 @@ def add_actor_information_and_train(
 
             optimizers["critic"].step()
 
+            # Grasp critic optimization
+            loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+            optimizers["grasp_critic"].zero_grad()
+            loss_grasp_critic.backward()
+
+            # clip gradients
+            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            )
+
+            optimizers["grasp_critic"].step()
+            
             policy.update_target_networks()
+            policy.update_grasp_target_networks()
 
         batch = replay_buffer.sample(batch_size=batch_size)
 
@@ -435,6 +450,7 @@ def add_actor_information_and_train(
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
+            "complementary_info": complementary_info,
         }
 
         # Use the forward method for critic loss
@@ -453,6 +469,22 @@ def add_actor_information_and_train(
         training_infos["loss_critic"] = loss_critic.item()
         training_infos["critic_grad_norm"] = critic_grad_norm
 
+        # Grasp critic optimization
+        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
+        optimizers["grasp_critic"].zero_grad()
+        loss_grasp_critic.backward()
+
+        # clip gradients
+        grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+        )
+
+        optimizers["grasp_critic"].step()
+
+        # Add training info for the grasp critic
+        training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
+        training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
                 # Use the forward method for actor loss
@@ -495,6 +527,7 @@ def add_actor_information_and_train(
             last_time_policy_pushed = time.time()
 
         policy.update_target_networks()
+        policy.update_grasp_target_networks()
 
         # Log training metrics at specified intervals
         if optimization_step % log_freq == 0:
@@ -729,11 +762,13 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
+        "grasp_critic": optimizer_grasp_critic,
         "temperature": optimizer_temperature,
     }
     return optimizers, lr_scheduler

From 7983baf4fcbabe77e2d946c2fe3f441e6581ebfc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 31 Mar 2025 16:10:00 +0000
Subject: [PATCH 06/28] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 lerobot/common/policies/sac/modeling_sac.py   | 23 ++++++++++++-------
 lerobot/scripts/server/buffer.py              | 12 ++++------
 .../server/end_effector_control_utils.py      | 10 ++++----
 lerobot/scripts/server/gym_manipulator.py     | 10 ++++----
 lerobot/scripts/server/learner_server.py      |  6 +++--
 5 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index bd74c65b..95ea3928 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -198,7 +198,7 @@ class SACPolicy(
 
     def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
         """Forward pass through a grasp critic network
-        
+
         Args:
             observations: Dictionary of observations
             use_target: If True, use the target grasp critic, otherwise use the online grasp critic
@@ -254,7 +254,7 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
-        
+
         if model == "grasp_critic":
             # Extract grasp_critic-specific components
             complementary_info: dict[str, Tensor] = batch["complementary_info"]
@@ -307,7 +307,7 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
-    
+
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
 
@@ -369,8 +369,17 @@ class SACPolicy(
         ).sum()
         return critics_loss
 
-    def compute_loss_grasp_critic(self, observations, actions, rewards, next_observations, done, observation_features=None, next_observation_features=None, complementary_info=None):
-
+    def compute_loss_grasp_critic(
+        self,
+        observations,
+        actions,
+        rewards,
+        next_observations,
+        done,
+        observation_features=None,
+        next_observation_features=None,
+        complementary_info=None,
+    ):
         batch_size = rewards.shape[0]
         grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
 
@@ -632,9 +641,7 @@ class GraspCritic(nn.Module):
         self.parameters_to_optimize += list(self.output_layer.parameters())
 
     def forward(
-        self,
-        observations: torch.Tensor,
-        observation_features: torch.Tensor | None = None
+        self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 1fbc8803..bf65b1ec 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -248,7 +248,7 @@ class ReplayBuffer:
         # Initialize complementary_info storage
         self.complementary_info_keys = []
         self.complementary_info_storage = {}
-        
+
         self.initialized = True
 
     def __len__(self):
@@ -291,11 +291,9 @@ class ReplayBuffer:
                     if isinstance(value, torch.Tensor):
                         shape = value.shape if value.ndim > 0 else (1,)
                         self.complementary_info_storage[key] = torch.zeros(
-                            (self.capacity, *shape), 
-                            dtype=value.dtype, 
-                            device=self.storage_device
+                            (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
                         )
-                        
+
                 # Store the value
                 if key in self.complementary_info_storage:
                     if isinstance(value, torch.Tensor):
@@ -304,7 +302,7 @@ class ReplayBuffer:
                         # For non-tensor values (like grasp_penalty)
                         self.complementary_info_storage[key][self.position] = torch.tensor(
                             value, device=self.storage_device
-                    )
+                        )
 
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
@@ -366,7 +364,7 @@ class ReplayBuffer:
 
         # Add complementary_info to batch if it exists
         batch_complementary_info = {}
-        if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
+        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
             for key in self.complementary_info_keys:
                 if key in self.complementary_info_storage:
                     batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py
index f272426d..3b2cfa90 100644
--- a/lerobot/scripts/server/end_effector_control_utils.py
+++ b/lerobot/scripts/server/end_effector_control_utils.py
@@ -311,7 +311,7 @@ class GamepadController(InputController):
         except pygame.error:
             logging.error("Error reading gamepad. Is it still connected?")
             return 0.0, 0.0, 0.0
-    
+
     def get_gripper_action(self):
         """
         Get gripper action using L3/R3 buttons.
@@ -319,12 +319,12 @@ class GamepadController(InputController):
         Press right stick (R3) to close the gripper.
         """
         import pygame
-        
+
         try:
             # Check if buttons are pressed
             l3_pressed = self.joystick.get_button(9)
             r3_pressed = self.joystick.get_button(10)
-            
+
             # Determine action based on button presses
             if r3_pressed:
                 return 1.0  # Close gripper
@@ -332,9 +332,9 @@ class GamepadController(InputController):
                 return -1.0  # Open gripper
             else:
                 return 0.0  # No change
-                
+
         except pygame.error:
-            logging.error(f"Error reading gamepad. Is it still connected?")
+            logging.error("Error reading gamepad. Is it still connected?")
             return 0.0
 
 
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 26ed1991..ac3bbb0a 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -1077,18 +1077,20 @@ class GripperPenaltyWrapper(gym.Wrapper):
 
     def reset(self, **kwargs):
         obs, info = self.env.reset(**kwargs)
-        self.last_gripper_pos = obs["observation.state"][0, 0] # first idx for the gripper
+        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
         return obs, info
 
     def step(self, action):
         observation, reward, terminated, truncated, info = self.env.step(action)
 
-        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (action[-1] > 0.5 and self.last_gripper_pos < 0.9):
+        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
+            action[-1] > 0.5 and self.last_gripper_pos < 0.9
+        ):
             info["grasp_penalty"] = self.penalty
         else:
             info["grasp_penalty"] = 0.0
 
-        self.last_gripper_pos = observation["observation.state"][0, 0] # first idx for the gripper
+        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
         return observation, reward, terminated, truncated, info
 
 
@@ -1167,7 +1169,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
-    env= GripperPenaltyWrapper(env=env)
+    env = GripperPenaltyWrapper(env=env)
 
     return env
 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index f79e8d57..0f760dc5 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -417,7 +417,7 @@ def add_actor_information_and_train(
             )
 
             optimizers["grasp_critic"].step()
-            
+
             policy.update_target_networks()
             policy.update_grasp_target_networks()
 
@@ -762,7 +762,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
-    optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
+    optimizer_grasp_critic = torch.optim.Adam(
+        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
+    )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {

From fe2ff516a8d354d33283163137a4a266145caa7d Mon Sep 17 00:00:00 2001
From: Michel Aractingi <michel.aractingi@huggingface.co>
Date: Tue, 1 Apr 2025 11:08:15 +0200
Subject: [PATCH 07/28] Add gripper quantization wrapper and grasp penalty,
 remove complementary info from buffer and learner server, remove the
 get_gripper_action function, and add gripper parameters to
 `common/envs/configs.py`

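The quantization wrapper reduces the continuous gripper command to a three-way
decision before it reaches the robot. A minimal sketch of the mapping (the
`MAX_GRIPPER_COMMAND` value below is a placeholder for the constant used in
gym_manipulator.py):

    import numpy as np

    MAX_GRIPPER_COMMAND = 1.0  # placeholder value for illustration

    def quantize_gripper(command: float, current_pos: float, threshold: float = 0.8) -> float:
        """Turn a continuous command into a full open/close/no-op step, clipped to the valid range."""
        if command < -threshold:
            delta = -MAX_GRIPPER_COMMAND
        elif command > threshold:
            delta = MAX_GRIPPER_COMMAND
        else:
            delta = 0.0
        return float(np.clip(current_pos + delta, 0.0, MAX_GRIPPER_COMMAND))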
---
 lerobot/common/envs/configs.py                |   3 +
 lerobot/scripts/server/buffer.py              |  34 ------
 .../server/end_effector_control_utils.py      |  25 -----
 lerobot/scripts/server/gym_manipulator.py     | 100 +++++++++++-------
 lerobot/scripts/server/learner_server.py      |   3 -
 5 files changed, 66 insertions(+), 99 deletions(-)

diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 825fa162..440512c3 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -203,6 +203,9 @@ class EnvWrapperConfig:
     joint_masking_action_space: Optional[Any] = None
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
+    gripper_quantization_threshold: float = 0.8
+    gripper_penalty: float = 0.0
+    open_gripper_on_reset: bool = False
 
 
 @EnvConfig.register_subclass(name="gym_manipulator")
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index bf65b1ec..2af3995e 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -245,10 +245,6 @@ class ReplayBuffer:
         self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
 
-        # Initialize complementary_info storage
-        self.complementary_info_keys = []
-        self.complementary_info_storage = {}
-
         self.initialized = True
 
     def __len__(self):
@@ -282,28 +278,6 @@ class ReplayBuffer:
         self.dones[self.position] = done
         self.truncateds[self.position] = truncated
 
-        # Store complementary info if provided
-        if complementary_info is not None:
-            # Initialize storage for new keys on first encounter
-            for key, value in complementary_info.items():
-                if key not in self.complementary_info_keys:
-                    self.complementary_info_keys.append(key)
-                    if isinstance(value, torch.Tensor):
-                        shape = value.shape if value.ndim > 0 else (1,)
-                        self.complementary_info_storage[key] = torch.zeros(
-                            (self.capacity, *shape), dtype=value.dtype, device=self.storage_device
-                        )
-
-                # Store the value
-                if key in self.complementary_info_storage:
-                    if isinstance(value, torch.Tensor):
-                        self.complementary_info_storage[key][self.position] = value
-                    else:
-                        # For non-tensor values (like grasp_penalty)
-                        self.complementary_info_storage[key][self.position] = torch.tensor(
-                            value, device=self.storage_device
-                        )
-
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
 
@@ -362,13 +336,6 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()
 
-        # Add complementary_info to batch if it exists
-        batch_complementary_info = {}
-        if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
-            for key in self.complementary_info_keys:
-                if key in self.complementary_info_storage:
-                    batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
-
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
@@ -376,7 +343,6 @@ class ReplayBuffer:
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
-            complementary_info=batch_complementary_info if batch_complementary_info else None,
         )
 
     @classmethod
diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py
index 3b2cfa90..3bd927b4 100644
--- a/lerobot/scripts/server/end_effector_control_utils.py
+++ b/lerobot/scripts/server/end_effector_control_utils.py
@@ -312,31 +312,6 @@ class GamepadController(InputController):
             logging.error("Error reading gamepad. Is it still connected?")
             return 0.0, 0.0, 0.0
 
-    def get_gripper_action(self):
-        """
-        Get gripper action using L3/R3 buttons.
-        Press left stick (L3) to open the gripper.
-        Press right stick (R3) to close the gripper.
-        """
-        import pygame
-
-        try:
-            # Check if buttons are pressed
-            l3_pressed = self.joystick.get_button(9)
-            r3_pressed = self.joystick.get_button(10)
-
-            # Determine action based on button presses
-            if r3_pressed:
-                return 1.0  # Close gripper
-            elif l3_pressed:
-                return -1.0  # Open gripper
-            else:
-                return 0.0  # No change
-
-        except pygame.error:
-            logging.error("Error reading gamepad. Is it still connected?")
-            return 0.0
-
 
 class GamepadControllerHID(InputController):
     """Generate motion deltas from gamepad input using HIDAPI."""
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index ac3bbb0a..3aa75466 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -761,6 +761,62 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
         return observation
 
 
+class GripperPenaltyWrapper(gym.RewardWrapper):
+    def __init__(self, env, penalty: float = -0.1):
+        super().__init__(env)
+        self.penalty = penalty
+        self.last_gripper_state = None
+
+    def reward(self, reward, action):
+        gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
+
+        if isinstance(action, tuple):
+            action = action[0]
+        action_normalized = action[-1] / MAX_GRIPPER_COMMAND
+
+        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
+            gripper_state_normalized > 0.9 and action_normalized < 0.1
+        )
+
+        return reward + self.penalty * gripper_penalty_bool
+
+    def step(self, action):
+        self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        reward = self.reward(reward, action)
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        self.last_gripper_state = None
+        return super().reset(**kwargs)
+
+
+class GripperQuantizationWrapper(gym.ActionWrapper):
+    def __init__(self, env, quantization_threshold: float = 0.2):
+        super().__init__(env)
+        self.quantization_threshold = quantization_threshold
+
+    def action(self, action):
+        is_intervention = False
+        if isinstance(action, tuple):
+            action, is_intervention = action
+
+        gripper_command = action[-1]
+        # Quantize gripper command to -1, 0 or 1
+        if gripper_command < -self.quantization_threshold:
+            gripper_command = -MAX_GRIPPER_COMMAND
+        elif gripper_command > self.quantization_threshold:
+            gripper_command = MAX_GRIPPER_COMMAND
+        else:
+            gripper_command = 0.0
+
+        gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
+        action[-1] = gripper_action.item()
+        return action, is_intervention
+
+
 class EEActionWrapper(gym.ActionWrapper):
     def __init__(self, env, ee_action_space_params=None, use_gripper=False):
         super().__init__(env)
@@ -820,17 +876,7 @@ class EEActionWrapper(gym.ActionWrapper):
             fk_func=self.fk_function,
         )
         if self.use_gripper:
-            # Quantize gripper command to -1, 0 or 1
-            if gripper_command < -0.2:
-                gripper_command = -1.0
-            elif gripper_command > 0.2:
-                gripper_command = 1.0
-            else:
-                gripper_command = 0.0
-
-            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
-            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
-            target_joint_pos[-1] = gripper_action
+            target_joint_pos[-1] = gripper_command
 
         return target_joint_pos, is_intervention
 
@@ -1069,31 +1115,6 @@ class ActionScaleWrapper(gym.ActionWrapper):
         return action * self.scale_vector, is_intervention
 
 
-class GripperPenaltyWrapper(gym.Wrapper):
-    def __init__(self, env, penalty=-0.05):
-        super().__init__(env)
-        self.penalty = penalty
-        self.last_gripper_pos = None
-
-    def reset(self, **kwargs):
-        obs, info = self.env.reset(**kwargs)
-        self.last_gripper_pos = obs["observation.state"][0, 0]  # first idx for the gripper
-        return obs, info
-
-    def step(self, action):
-        observation, reward, terminated, truncated, info = self.env.step(action)
-
-        if (action[-1] < -0.5 and self.last_gripper_pos > 0.9) or (
-            action[-1] > 0.5 and self.last_gripper_pos < 0.9
-        ):
-            info["grasp_penalty"] = self.penalty
-        else:
-            info["grasp_penalty"] = 0.0
-
-        self.last_gripper_pos = observation["observation.state"][0, 0]  # first idx for the gripper
-        return observation, reward, terminated, truncated, info
-
-
 def make_robot_env(cfg) -> gym.vector.VectorEnv:
     """
     Factory function to create a vectorized robot environment.
@@ -1143,6 +1164,12 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # Add reward computation and control wrappers
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
+    if cfg.wrapper.use_gripper:
+        env = GripperQuantizationWrapper(
+            env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
+        )
+        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+
     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
             env=env,
@@ -1169,7 +1196,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
     env = BatchCompitableWrapper(env=env)
-    env = GripperPenaltyWrapper(env=env)
 
     return env
 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 0f760dc5..15de2cb7 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -375,7 +375,6 @@ def add_actor_information_and_train(
             observations = batch["state"]
             next_observations = batch["next_state"]
             done = batch["done"]
-            complementary_info = batch["complementary_info"]
             check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
             observation_features, next_observation_features = get_observation_features(
@@ -391,7 +390,6 @@ def add_actor_information_and_train(
                 "done": done,
                 "observation_feature": observation_features,
                 "next_observation_feature": next_observation_features,
-                "complementary_info": complementary_info,
             }
 
             # Use the forward method for critic loss
@@ -450,7 +448,6 @@ def add_actor_information_and_train(
             "done": done,
             "observation_feature": observation_features,
             "next_observation_feature": next_observation_features,
-            "complementary_info": complementary_info,
         }
 
         # Use the forward method for critic loss

From 6a215f47ddf7e3235736f60cefb4ffb42166406c Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Tue, 1 Apr 2025 09:30:32 +0000
Subject: [PATCH 08/28] Refactor SAC configuration and policy to support
 discrete actions

- Removed GraspCriticNetworkConfig class and integrated its parameters into SACConfig.
- Added num_discrete_actions parameter to SACConfig for better action handling.
- Updated SACPolicy to conditionally create grasp critic networks based on num_discrete_actions.
- Enhanced grasp critic forward pass to handle discrete actions and compute losses accordingly.
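
The discrete-action loss follows the standard double-DQN pattern: the online
head selects the next action, the target head evaluates it, and the TD target
is regressed with an MSE loss. A self-contained sketch with random tensors
standing in for the network outputs:

    import torch
    import torch.nn.functional as F

    batch_size, num_actions, gamma = 8, 3, 0.99
    q_online_next = torch.randn(batch_size, num_actions)  # online head on next observations
    q_target_next = torch.randn(batch_size, num_actions)  # target head on next observations
    q_pred = torch.randn(batch_size, num_actions, requires_grad=True)  # online head on current observations
    taken_actions = torch.randint(0, num_actions, (batch_size, 1))
    rewards, done = torch.randn(batch_size), torch.zeros(batch_size)

    best_next = torch.argmax(q_online_next, dim=-1, keepdim=True)
    td_target = rewards + (1 - done) * gamma * q_target_next.gather(1, best_next).squeeze(-1)
    loss = F.mse_loss(q_pred.gather(1, taken_actions).squeeze(-1), td_target)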
---
 .../common/policies/sac/configuration_sac.py  |   9 +-
 lerobot/common/policies/sac/modeling_sac.py   | 103 +++++++++++-------
 2 files changed, 68 insertions(+), 44 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index b1ce30b6..66d9aa45 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -42,12 +42,6 @@ class CriticNetworkConfig:
     final_activation: str | None = None
 
 
-@dataclass
-class GraspCriticNetworkConfig:
-    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
-    activate_final: bool = True
-    final_activation: str | None = None
-    output_dim: int = 3
 
 
 @dataclass
@@ -152,6 +146,7 @@ class SACConfig(PreTrainedConfig):
     freeze_vision_encoder: bool = True
     image_encoder_hidden_dim: int = 32
     shared_encoder: bool = True
+    num_discrete_actions: int | None = None
 
     # Training parameter
     online_steps: int = 1000000
@@ -182,7 +177,7 @@ class SACConfig(PreTrainedConfig):
     critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
     policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
-
+    grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
     actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
     concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
 
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 95ea3928..dd156918 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -33,6 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
 
+DISCRETE_DIMENSION_INDEX = -1
 
 class SACPolicy(
     PreTrainedPolicy,
@@ -113,24 +114,30 @@ class SACPolicy(
         self.critic_ensemble = torch.compile(self.critic_ensemble)
         self.critic_target = torch.compile(self.critic_target)
 
-        # Create grasp critic
-        self.grasp_critic = GraspCritic(
-            encoder=encoder_critic,
-            input_dim=encoder_critic.output_dim,
-            **config.grasp_critic_network_kwargs,
-        )
+        self.grasp_critic = None
+        self.grasp_critic_target = None
 
-        # Create target grasp critic
-        self.grasp_critic_target = GraspCritic(
-            encoder=encoder_critic,
-            input_dim=encoder_critic.output_dim,
-            **config.grasp_critic_network_kwargs,
-        )
+        if config.num_discrete_actions is not None:
+            # Create grasp critic
+            self.grasp_critic = GraspCritic(
+                encoder=encoder_critic,
+                input_dim=encoder_critic.output_dim,
+                output_dim=config.num_discrete_actions,
+                **asdict(config.grasp_critic_network_kwargs),
+            )
 
-        self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+            # Create target grasp critic
+            self.grasp_critic_target = GraspCritic(
+                encoder=encoder_critic,
+                input_dim=encoder_critic.output_dim,
+                output_dim=config.num_discrete_actions,
+                **asdict(config.grasp_critic_network_kwargs),
+            )
 
-        self.grasp_critic = torch.compile(self.grasp_critic)
-        self.grasp_critic_target = torch.compile(self.grasp_critic_target)
+            self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
+
+            self.grasp_critic = torch.compile(self.grasp_critic)
+            self.grasp_critic_target = torch.compile(self.grasp_critic_target)
 
         self.actor = Policy(
             encoder=encoder_actor,
@@ -173,6 +180,12 @@ class SACPolicy(
         """Select action for inference/evaluation"""
         actions, _, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
+
+        if self.config.num_discrete_actions is not None:
+            discrete_action_value = self.grasp_critic(batch)
+            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
+            actions = torch.cat([actions, discrete_action], dim=-1)
+
         return actions
 
     def critic_forward(
@@ -192,11 +205,12 @@ class SACPolicy(
         Returns:
             Tensor of Q-values from all critics
         """
+
         critics = self.critic_target if use_target else self.critic_ensemble
         q_values = critics(observations, actions, observation_features)
         return q_values
 
-    def grasp_critic_forward(self, observations, use_target=False, observation_features=None):
+    def grasp_critic_forward(self, observations, use_target=False, observation_features=None) -> torch.Tensor:
         """Forward pass through a grasp critic network
 
         Args:
@@ -256,9 +270,6 @@ class SACPolicy(
             )
 
         if model == "grasp_critic":
-            # Extract grasp_critic-specific components
-            complementary_info: dict[str, Tensor] = batch["complementary_info"]
-
             return self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
@@ -267,7 +278,6 @@ class SACPolicy(
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
-                complementary_info=complementary_info,
             )
 
         if model == "actor":
@@ -349,6 +359,12 @@ class SACPolicy(
             td_target = rewards + (1 - done) * self.config.discount * min_q
 
         # 3- compute predicted qs
+        if self.config.num_discrete_actions is not None:
+            # NOTE: We only want to keep the continuous action part
+            # In the buffer we have the full action space (continuous + discrete)
+            # We need to strip the discrete part before passing the actions to the critic forward
+            actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
+
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -378,30 +394,43 @@ class SACPolicy(
         done,
         observation_features=None,
         next_observation_features=None,
-        complementary_info=None,
     ):
-        batch_size = rewards.shape[0]
-        grasp_actions = torch.clip(actions[:, -1].long() + 1, 0, 2)  # Map [-1, 0, 1] -> [0, 1, 2]
+        # NOTE: We only want to keep the discrete action part
+        # In the buffer we have the full action space (continuous + discrete)
+        # We need to split off the discrete part before computing the grasp critic loss
+        actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
+        actions = actions.long()
 
         with torch.no_grad():
+            # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
             best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+            
+            # Get target Q-values from target network
+            target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
+            
+            # Use gather to select Q-values for best actions
+            target_next_grasp_q = torch.gather(
+                target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
+            ).squeeze(-1)
 
-            target_next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=True)
-            target_next_grasp_q = target_next_grasp_qs[torch.arange(batch_size), best_next_grasp_action]
+        # Compute target Q-value with Bellman equation
+        target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
 
-        # Get the grasp penalty from complementary_info
-        grasp_penalty = torch.zeros_like(rewards)
-        if complementary_info is not None and "grasp_penalty" in complementary_info:
-            grasp_penalty = complementary_info["grasp_penalty"]
+        # Get predicted Q-values for current observations
+        predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
+        
+        # Use gather to select Q-values for the taken actions (already shaped [batch, 1])
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions).squeeze(-1)
 
-        grasp_rewards = rewards + grasp_penalty
-        target_grasp_q = grasp_rewards + (1 - done) * self.config.discount * target_next_grasp_q
-
-        predicted_grasp_qs = self.grasp_critic_forward(observations, use_target=False)
-        predicted_grasp_q = predicted_grasp_qs[torch.arange(batch_size), grasp_actions]
-
-        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q, reduction="mean")
+        # Compute MSE loss between predicted and target Q-values
+        grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
         return grasp_critic_loss
 
     def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
@@ -611,7 +640,7 @@ class GraspCritic(nn.Module):
         self,
         encoder: Optional[nn.Module],
         network: nn.Module,
-        output_dim: int = 3,
+        output_dim: int = 3,  # TODO (azouitine): rename to something like num_discrete_actions
         init_final: Optional[float] = None,
         encoder_is_shared: bool = False,
     ):

From 306c735172a6ffd0d1f1fa0866f29abd8d2c597c Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Tue, 1 Apr 2025 11:42:28 +0000
Subject: [PATCH 09/28] Refactor SAC policy and training loop to enhance
 discrete action support

- Updated SACPolicy to conditionally compute losses for grasp critic based on num_discrete_actions.
- Simplified forward method to return loss outputs as a dictionary for better clarity.
- Adjusted learner_server to handle both main and grasp critic losses during training.
- Ensured optimizers are created conditionally for grasp critic based on configuration settings.
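
With losses returned as a dictionary, the learner can step exactly the
optimizers whose losses are present. A toy, runnable sketch of that dispatch
(two independent linear modules stand in for the real critics, and the key
naming convention is an assumption):

    import torch

    models = {"critic": torch.nn.Linear(4, 1), "grasp_critic": torch.nn.Linear(4, 3)}
    optimizers = {name: torch.optim.Adam(m.parameters(), lr=3e-4) for name, m in models.items()}

    x = torch.randn(8, 4)
    losses = {f"loss_{name}": m(x).pow(2).mean() for name, m in models.items()}

    for loss_name, loss in losses.items():
        opt = optimizers[loss_name.removeprefix("loss_")]
        opt.zero_grad()
        loss.backward()
        opt.step()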
---
 .../common/policies/sac/configuration_sac.py  |   2 +-
 lerobot/common/policies/sac/modeling_sac.py   |  54 ++++----
 lerobot/scripts/server/learner_server.py      | 120 +++++++++---------
 3 files changed, 86 insertions(+), 90 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 66d9aa45..ae38b1c5 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -87,6 +87,7 @@ class SACConfig(PreTrainedConfig):
         freeze_vision_encoder: Whether to freeze the vision encoder during training.
         image_encoder_hidden_dim: Hidden dimension size for the image encoder.
         shared_encoder: Whether to use a shared encoder for actor and critic.
+        num_discrete_actions: Number of discrete actions, e.g. for gripper actions.
         concurrency: Configuration for concurrency settings.
         actor_learner: Configuration for actor-learner architecture.
         online_steps: Number of steps for online training.
@@ -162,7 +163,6 @@ class SACConfig(PreTrainedConfig):
     num_critics: int = 2
     num_subsample_critics: int | None = None
     critic_lr: float = 3e-4
-    grasp_critic_lr: float = 3e-4
     actor_lr: float = 3e-4
     temperature_lr: float = 3e-4
     critic_target_update_weight: float = 0.005
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index dd156918..d0e8b25d 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -228,7 +228,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
 
@@ -246,7 +246,6 @@ class SACPolicy(
         Returns:
             The computed loss tensor
         """
-        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
         actions: Tensor = batch["action"]
         observations: dict[str, Tensor] = batch["state"]
@@ -259,7 +258,7 @@ class SACPolicy(
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
 
-            return self.compute_loss_critic(
+            loss_critic =  self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -268,29 +267,28 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+            if self.config.num_discrete_actions is not None:
+                loss_grasp_critic = self.compute_loss_grasp_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+            return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
 
-        if model == "grasp_critic":
-            return self.compute_loss_grasp_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
 
         if model == "actor":
-            return self.compute_loss_actor(
+            return {"loss_actor": self.compute_loss_actor(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}
 
         if model == "temperature":
-            return self.compute_loss_temperature(
+            return {"loss_temperature": self.compute_loss_temperature(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}
 
         raise ValueError(f"Unknown model type: {model}")
 
@@ -305,18 +303,16 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
-
-    def update_grasp_target_networks(self):
-        """Update grasp target networks with exponential moving average"""
-        for target_param, param in zip(
-            self.grasp_critic_target.parameters(),
-            self.grasp_critic.parameters(),
-            strict=False,
-        ):
-            target_param.data.copy_(
-                param.data * self.config.critic_target_update_weight
-                + target_param.data * (1.0 - self.config.critic_target_update_weight)
-            )
+        if self.config.num_discrete_actions is not None:
+            for target_param, param in zip(
+                self.grasp_critic_target.parameters(),
+                self.grasp_critic.parameters(),
+                strict=False,
+            ):
+                target_param.data.copy_(
+                    param.data * self.config.critic_target_update_weight
+                    + target_param.data * (1.0 - self.config.critic_target_update_weight)
+                )
 
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 15de2cb7..627a1a17 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -392,32 +392,30 @@ def add_actor_information_and_train(
                 "next_observation_feature": next_observation_features,
             }
 
-            # Use the forward method for critic loss
-            loss_critic = policy.forward(forward_batch, model="critic")
+            # Use the forward method for critic loss (includes both main critic and grasp critic)
+            critic_output = policy.forward(forward_batch, model="critic")
+            
+            # Main critic optimization
+            loss_critic = critic_output["loss_critic"]
             optimizers["critic"].zero_grad()
             loss_critic.backward()
-
-            # clip gradients
             critic_grad_norm = torch.nn.utils.clip_grad_norm_(
                 parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
             )
-
             optimizers["critic"].step()
 
-            # Add gripper critic optimization
-            loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
-            optimizers["grasp_critic"].zero_grad()
-            loss_grasp_critic.backward()
-
-            # clip gradients
-            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
-            )
-
-            optimizers["grasp_critic"].step()
+            # Grasp critic optimization (if available)
+            if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
+                loss_grasp_critic = critic_output["loss_grasp_critic"]
+                optimizers["grasp_critic"].zero_grad()
+                loss_grasp_critic.backward()
+                grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                )
+                optimizers["grasp_critic"].step()
 
+            # Update target networks
             policy.update_target_networks()
-            policy.update_grasp_target_networks()
 
         batch = replay_buffer.sample(batch_size=batch_size)
 
@@ -450,81 +448,80 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }
 
-        # Use the forward method for critic loss
-        loss_critic = policy.forward(forward_batch, model="critic")
+        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        critic_output = policy.forward(forward_batch, model="critic")
+        
+        # Main critic optimization
+        loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
         loss_critic.backward()
-
-        # clip gradients
         critic_grad_norm = torch.nn.utils.clip_grad_norm_(
             parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
         ).item()
-
         optimizers["critic"].step()
 
-        training_infos = {}
-        training_infos["loss_critic"] = loss_critic.item()
-        training_infos["critic_grad_norm"] = critic_grad_norm
+        # Initialize training info dictionary
+        training_infos = {
+            "loss_critic": loss_critic.item(),
+            "critic_grad_norm": critic_grad_norm,
+        }
 
-        # Add gripper critic optimization
-        loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
-        optimizers["grasp_critic"].zero_grad()
-        loss_grasp_critic.backward()
-
-        # clip gradients
-        grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-            parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
-        )
-
-        optimizers["grasp_critic"].step()
-
-        # Add training info for the grasp critic
-        training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
-        training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+        # Grasp critic optimization (if available)
+        if "loss_grasp_critic" in critic_output:
+            loss_grasp_critic = critic_output["loss_grasp_critic"]
+            optimizers["grasp_critic"].zero_grad()
+            loss_grasp_critic.backward()
+            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            ).item()
+            optimizers["grasp_critic"].step()
+            
+            # Add grasp critic info to training info
+            training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
+            training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
 
+        # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
             for _ in range(policy_update_freq):
-                # Use the forward method for actor loss
-                loss_actor = policy.forward(forward_batch, model="actor")
-
+                # Actor optimization
+                actor_output = policy.forward(forward_batch, model="actor")
+                loss_actor = actor_output["loss_actor"]
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
-
-                # clip gradients
                 actor_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()
-
                 optimizers["actor"].step()
-
+                
+                # Add actor info to training info
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm
 
-                # Temperature optimization using forward method
-                loss_temperature = policy.forward(forward_batch, model="temperature")
+                # Temperature optimization
+                temperature_output = policy.forward(forward_batch, model="temperature")
+                loss_temperature = temperature_output["loss_temperature"]
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
-
-                #  clip gradients
                 temp_grad_norm = torch.nn.utils.clip_grad_norm_(
                     parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                 ).item()
-
                 optimizers["temperature"].step()
-
+                
+                # Add temperature info to training info
                 training_infos["loss_temperature"] = loss_temperature.item()
                 training_infos["temperature_grad_norm"] = temp_grad_norm
                 training_infos["temperature"] = policy.temperature
 
+                # Update temperature
                 policy.update_temperature()
 
-        # Check if it's time to push updated policy to actors
+        # Push policy to actors if needed
         if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
             push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
             last_time_policy_pushed = time.time()
 
+        # Update target networks
         policy.update_target_networks()
-        policy.update_grasp_target_networks()
 
         # Log training metrics at specified intervals
         if optimization_step % log_freq == 0:
@@ -727,7 +724,7 @@ def save_training_checkpoint(
     logging.info("Resume training")
 
 
-def make_optimizers_and_scheduler(cfg, policy: nn.Module):
+def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
 
@@ -759,17 +756,20 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
-    optimizer_grasp_critic = torch.optim.Adam(
-        params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
-    )
+    
+    if cfg.policy.num_discrete_actions is not None:
+        optimizer_grasp_critic = torch.optim.Adam(
+            params=policy.grasp_critic.parameters(), lr=policy.critic_lr
+        )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
         "critic": optimizer_critic,
-        "grasp_critic": optimizer_grasp_critic,
         "temperature": optimizer_temperature,
     }
+    if cfg.policy.num_discrete_actions is not None:
+        optimizers["grasp_critic"] = optimizer_grasp_critic
     return optimizers, lr_scheduler
 
 

From 451a7b01db13d3e7dc4a111a1f1513f1a1eba396 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Tue, 1 Apr 2025 14:22:08 +0000
Subject: [PATCH 10/28] Add mock gripper support and enhance SAC policy action
 handling

- Introduced mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation.
- Added ManiskillMockGripperWrapper to adjust action space for environments with discrete actions.
- Updated SACPolicy to compute continuous action dimensions correctly, ensuring compatibility with the new gripper setup (a small sketch of the continuous/discrete action split follows the diffstat below).
- Refactored action handling in the training loop to accommodate the changes in action dimensions.
---
 lerobot/common/envs/configs.py                |  1 +
 lerobot/common/policies/sac/modeling_sac.py   | 18 +++--
 .../scripts/server/maniskill_manipulator.py   | 71 ++++++++++++-------
 3 files changed, 59 insertions(+), 31 deletions(-)
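
The action-dimension bookkeeping behind these changes (the last dimension of a stored action carries the discrete gripper command, everything before it stays continuous) can be summarised with a small sketch; the tensors and shapes below are illustrative, not taken from the code:

import torch

DISCRETE_DIMENSION_INDEX = -1  # the gripper command lives in the last action dimension

# A batch of 4 full actions: 6 continuous dimensions plus 1 discrete gripper command.
full_actions = torch.cat([torch.randn(4, 6), torch.randint(0, 3, (4, 1)).float()], dim=-1)

continuous_part = full_actions[:, :DISCRETE_DIMENSION_INDEX]      # (4, 6), consumed by the SAC actor/critic
gripper_part = full_actions[:, DISCRETE_DIMENSION_INDEX:].long()  # (4, 1), consumed by the grasp critic

assert continuous_part.shape == (4, 6)
assert gripper_part.shape == (4, 1)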

diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 440512c3..a6eda93b 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -257,6 +257,7 @@ class ManiskillEnvConfig(EnvConfig):
     robot: str = "so100"  # This is a hack to make the robot config work
     video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
     wrapper: WrapperConfig = field(default_factory=WrapperConfig)
+    mock_gripper: bool = False
     features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
             "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index d0e8b25d..0c3d76d2 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -33,7 +33,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
 
-DISCRETE_DIMENSION_INDEX = -1   
+DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
 
 class SACPolicy(
     PreTrainedPolicy,
@@ -82,7 +82,7 @@ class SACPolicy(
         # Create a list of critic heads
         critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
@@ -97,7 +97,7 @@ class SACPolicy(
         # Create target critic heads as deepcopies of the original critic heads
         target_critic_heads = [
             CriticHead(
-                input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
+                input_dim=encoder_critic.output_dim + continuous_action_dim,
                 **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
@@ -117,7 +117,10 @@ class SACPolicy(
         self.grasp_critic = None
         self.grasp_critic_target = None
 
+        continuous_action_dim = config.output_features["action"].shape[0]
         if config.num_discrete_actions is not None:
+
+            continuous_action_dim -= 1
             # Create grasp critic
             self.grasp_critic = GraspCritic(
                 encoder=encoder_critic,
@@ -139,15 +142,16 @@ class SACPolicy(
             self.grasp_critic = torch.compile(self.grasp_critic)
             self.grasp_critic_target = torch.compile(self.grasp_critic_target)
 
+
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
-            action_dim=config.output_features["action"].shape[0],
+            action_dim=continuous_action_dim,
             encoder_is_shared=config.shared_encoder,
             **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2  # (-dim(A)/2)
+            config.target_entropy = -np.prod(continuous_action_dim) / 2  # (-dim(A)/2)
 
         # TODO (azouitine): Handle the case where the temperature parameter is fixed
         # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -275,7 +279,9 @@ class SACPolicy(
                     next_observations=next_observations,
                     done=done,
                 )
-            return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
+                return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
+
+            return {"loss_critic": loss_critic}
 
 
         if model == "actor":
diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py
index e10b8766..f4a89888 100644
--- a/lerobot/scripts/server/maniskill_manipulator.py
+++ b/lerobot/scripts/server/maniskill_manipulator.py
@@ -11,6 +11,10 @@ from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
 
 from lerobot.common.envs.configs import ManiskillEnvConfig
 from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
+from lerobot.common.policies.sac.configuration_sac import SACConfig
+from lerobot.common.policies.sac.modeling_sac import SACPolicy
+
 
 
 def preprocess_maniskill_observation(
@@ -152,6 +156,21 @@ class TimeLimitWrapper(gym.Wrapper):
         self.current_step = 0
         return super().reset(seed=seed, options=options)
 
+class ManiskillMockGripperWrapper(gym.Wrapper):
+    def __init__(self, env, nb_discrete_actions: int = 3):
+        super().__init__(env)
+        new_shape = env.action_space[0].shape[0] + 1
+        new_low = np.concatenate([env.action_space[0].low, [0]])
+        new_high = np.concatenate([env.action_space[0].high, [nb_discrete_actions - 1]])
+        action_space_agent = gym.spaces.Box(low=new_low, high=new_high, shape=(new_shape,))
+        self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
+
+    def step(self, action):
+        action_agent, telop_action  = action  
+        real_action = action_agent[:-1]
+        final_action = (real_action, telop_action)
+        obs, reward, terminated, truncated, info = self.env.step(final_action)
+        return obs, reward, terminated, truncated, info 
 
 def make_maniskill(
     cfg: ManiskillEnvConfig,
@@ -197,40 +216,42 @@ def make_maniskill(
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
     env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)  # Scale actions for better control
+    if cfg.mock_gripper:
+        env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)
 
     return env
 
 
-@parser.wrap()
-def main(cfg: ManiskillEnvConfig):
-    """Main function to run the ManiSkill environment."""
-    # Create the ManiSkill environment
-    env = make_maniskill(cfg, n_envs=1)
+# @parser.wrap()
+# def main(cfg: TrainPipelineConfig):
+#     """Main function to run the ManiSkill environment."""
+#     # Create the ManiSkill environment
+#     env = make_maniskill(cfg.env, n_envs=1)
 
-    # Reset the environment
-    obs, info = env.reset()
+#     # Reset the environment
+#     obs, info = env.reset()
 
-    # Run a simple interaction loop
-    sum_reward = 0
-    for i in range(100):
-        # Sample a random action
-        action = env.action_space.sample()
+#     # Run a simple interaction loop
+#     sum_reward = 0
+#     for i in range(100):
+#         # Sample a random action
+#         action = env.action_space.sample()
 
-        # Step the environment
-        start_time = time.perf_counter()
-        obs, reward, terminated, truncated, info = env.step(action)
-        step_time = time.perf_counter() - start_time
-        sum_reward += reward
-        # Log information
+#         # Step the environment
+#         start_time = time.perf_counter()
+#         obs, reward, terminated, truncated, info = env.step(action)
+#         step_time = time.perf_counter() - start_time
+#         sum_reward += reward
+#         # Log information
 
-        # Reset if episode terminated
-        if terminated or truncated:
-            logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
-            sum_reward = 0
-            obs, info = env.reset()
+#         # Reset if episode terminated
+#         if terminated or truncated:
+#             logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
+#             sum_reward = 0
+#             obs, info = env.reset()
 
-    # Close the environment
-    env.close()
+#     # Close the environment
+#     env.close()
 
 
 # if __name__ == "__main__":

From 699d374d895f8facdd2e5b66e379508e2466c97f Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Tue, 1 Apr 2025 15:43:29 +0000
Subject: [PATCH 11/28] Refactor SACPolicy for improved readability and action
 dimension handling

- Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines.
- Consolidated continuous action dimension calculation to enhance clarity and maintainability.
- Simplified loss return statements in the forward method to improve code structure.
- Ensured grasp critic parameters are included conditionally based on configuration settings (an optimizer-construction sketch follows the diffstat below).
---
 lerobot/common/policies/sac/modeling_sac.py | 59 +++++++++++----------
 lerobot/scripts/server/learner_server.py    | 14 ++---
 2 files changed, 37 insertions(+), 36 deletions(-)
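
As a sketch of how the conditional parameter groups can be consumed, the helper below builds one Adam optimizer per entry returned by get_optim_params(); the function name and the single shared learning rate are illustrative assumptions, not part of this patch:

import torch

def build_sac_optimizers(policy, lr: float = 3e-4) -> dict[str, torch.optim.Optimizer]:
    """Build one Adam optimizer per parameter group exposed by the policy.

    get_optim_params() only contains a "grasp_critic" entry when
    config.num_discrete_actions is not None, so the extra optimizer is
    created exactly when the grasp critic exists.
    """
    optimizers = {}
    for name, params in policy.get_optim_params().items():
        if isinstance(params, torch.Tensor):  # e.g. the single log_alpha parameter
            params = [params]
        optimizers[name] = torch.optim.Adam(params=params, lr=lr)
    return optimizers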

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 0c3d76d2..41ff7d8c 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -33,7 +33,8 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.utils import get_device_from_parameters
 
-DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
+DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension
+
 
 class SACPolicy(
     PreTrainedPolicy,
@@ -50,6 +51,10 @@ class SACPolicy(
         config.validate_features()
         self.config = config
 
+        continuous_action_dim = config.output_features["action"].shape[0]
+        if config.num_discrete_actions is not None:
+            continuous_action_dim -= 1
+
         if config.dataset_stats is not None:
             input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
             self.normalize_inputs = Normalize(
@@ -117,10 +122,7 @@ class SACPolicy(
         self.grasp_critic = None
         self.grasp_critic_target = None
 
-        continuous_action_dim = config.output_features["action"].shape[0]
         if config.num_discrete_actions is not None:
-
-            continuous_action_dim -= 1
             # Create grasp critic
             self.grasp_critic = GraspCritic(
                 encoder=encoder_critic,
@@ -142,7 +144,6 @@ class SACPolicy(
             self.grasp_critic = torch.compile(self.grasp_critic)
             self.grasp_critic_target = torch.compile(self.grasp_critic_target)
 
-
         self.actor = Policy(
             encoder=encoder_actor,
             network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
@@ -162,11 +163,14 @@ class SACPolicy(
         self.temperature = self.log_alpha.exp().item()
 
     def get_optim_params(self) -> dict:
-        return {
+        optim_params = {
             "actor": self.actor.parameters_to_optimize,
             "critic": self.critic_ensemble.parameters_to_optimize,
             "temperature": self.log_alpha,
         }
+        if self.config.num_discrete_actions is not None:
+            optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
+        return optim_params
 
     def reset(self):
         """Reset the policy"""
@@ -262,7 +266,7 @@ class SACPolicy(
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
 
-            loss_critic =  self.compute_loss_critic(
+            loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -283,18 +287,21 @@ class SACPolicy(
 
             return {"loss_critic": loss_critic}
 
-
         if model == "actor":
-            return {"loss_actor": self.compute_loss_actor(
-                observations=observations,
-                observation_features=observation_features,
-            )}
+            return {
+                "loss_actor": self.compute_loss_actor(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+            }
 
         if model == "temperature":
-            return {"loss_temperature": self.compute_loss_temperature(
-                observations=observations,
-                observation_features=observation_features,
-            )}
+            return {
+                "loss_temperature": self.compute_loss_temperature(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+            }
 
         raise ValueError(f"Unknown model type: {model}")
 
@@ -366,7 +373,7 @@ class SACPolicy(
             # In the buffer we have the full action space (continuous + discrete)
             # We need to split them before concatenating them in the critic forward
             actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
-            
+
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -407,15 +414,13 @@ class SACPolicy(
             # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
             best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
-            
+
             # Get target Q-values from target network
             target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
-            
+
             # Use gather to select Q-values for best actions
             target_next_grasp_q = torch.gather(
-                target_next_grasp_qs, 
-                dim=1, 
-                index=best_next_grasp_action.unsqueeze(-1)
+                target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
             ).squeeze(-1)
 
         # Compute target Q-value with Bellman equation
@@ -423,13 +428,9 @@ class SACPolicy(
 
         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
-        
+
         # Use gather to select Q-values for taken actions
-        predicted_grasp_q = torch.gather(
-            predicted_grasp_qs,
-            dim=1,
-            index=actions.unsqueeze(-1)
-        ).squeeze(-1)
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
 
         # Compute MSE loss between predicted and target Q-values
         grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
@@ -642,7 +643,7 @@ class GraspCritic(nn.Module):
         self,
         encoder: Optional[nn.Module],
         network: nn.Module,
-        output_dim: int = 3, # TODO (azouitine): rename it to number of discrete actions or something like that
+        output_dim: int = 3,  # TODO (azouitine): rename it to number of discrete actions or something like that
         init_final: Optional[float] = None,
         encoder_is_shared: bool = False,
     ):
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 627a1a17..c57f83fc 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -394,7 +394,7 @@ def add_actor_information_and_train(
 
             # Use the forward method for critic loss (includes both main critic and grasp critic)
             critic_output = policy.forward(forward_batch, model="critic")
-            
+
             # Main critic optimization
             loss_critic = critic_output["loss_critic"]
             optimizers["critic"].zero_grad()
@@ -405,7 +405,7 @@ def add_actor_information_and_train(
             optimizers["critic"].step()
 
             # Grasp critic optimization (if available)
-            if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
+            if "loss_grasp_critic" in critic_output:
                 loss_grasp_critic = critic_output["loss_grasp_critic"]
                 optimizers["grasp_critic"].zero_grad()
                 loss_grasp_critic.backward()
@@ -450,7 +450,7 @@ def add_actor_information_and_train(
 
         # Use the forward method for critic loss (includes both main critic and grasp critic)
         critic_output = policy.forward(forward_batch, model="critic")
-        
+
         # Main critic optimization
         loss_critic = critic_output["loss_critic"]
         optimizers["critic"].zero_grad()
@@ -475,7 +475,7 @@ def add_actor_information_and_train(
                 parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
             ).item()
             optimizers["grasp_critic"].step()
-            
+
             # Add grasp critic info to training info
             training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
             training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
@@ -492,7 +492,7 @@ def add_actor_information_and_train(
                     parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
                 ).item()
                 optimizers["actor"].step()
-                
+
                 # Add actor info to training info
                 training_infos["loss_actor"] = loss_actor.item()
                 training_infos["actor_grad_norm"] = actor_grad_norm
@@ -506,7 +506,7 @@ def add_actor_information_and_train(
                     parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
                 ).item()
                 optimizers["temperature"].step()
-                
+
                 # Add temperature info to training info
                 training_infos["loss_temperature"] = loss_temperature.item()
                 training_infos["temperature_grad_norm"] = temp_grad_norm
@@ -756,7 +756,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         lr=cfg.policy.actor_lr,
     )
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
-    
+
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
             params=policy.grasp_critic.parameters(), lr=policy.critic_lr

From 0ed7ff142cd45992469cf2aee7a4f36f8571bb92 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Wed, 2 Apr 2025 15:50:39 +0000
Subject: [PATCH 12/28] Enhance SACPolicy and learner server for improved grasp
 critic integration

- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions.
- Refactored the forward method to handle grasp critic model selection and loss computation more clearly.
- Adjusted learner server to utilize optimized parameters for grasp critic during training.
- Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs (a small helper sketch follows the diffstat below).
---
 lerobot/common/policies/sac/modeling_sac.py   | 95 +++++++++++--------
 lerobot/scripts/server/learner_server.py      | 16 ++--
 .../scripts/server/maniskill_manipulator.py   | 11 ++-
 3 files changed, 72 insertions(+), 50 deletions(-)
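
The wrapper's new action handling can be captured by a small hypothetical helper; split_mock_gripper_action is not part of the patch, it just mirrors the branch added to ManiskillMockGripperWrapper.step():

import numpy as np

def split_mock_gripper_action(action, default_teleop: int = 0):
    """Accept either a (agent_action, teleop_action) tuple or a bare agent action,
    and strip the appended discrete gripper dimension before it reaches the wrapped env."""
    if isinstance(action, tuple):
        action_agent, teleop_action = action
    else:
        action_agent, teleop_action = action, default_teleop
    return action_agent[:-1], teleop_action

# Bare action: 6 continuous dimensions plus an appended gripper command.
agent_only = np.concatenate([np.zeros(6), [2.0]])
real_action, teleop = split_mock_gripper_action(agent_only)
assert real_action.shape == (6,) and teleop == 0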

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 41ff7d8c..af624592 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -52,8 +52,6 @@ class SACPolicy(
         self.config = config
 
         continuous_action_dim = config.output_features["action"].shape[0]
-        if config.num_discrete_actions is not None:
-            continuous_action_dim -= 1
 
         if config.dataset_stats is not None:
             input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
@@ -191,7 +189,7 @@ class SACPolicy(
 
         if self.config.num_discrete_actions is not None:
             discrete_action_value = self.grasp_critic(batch)
-            discrete_action = torch.argmax(discrete_action_value, dim=-1)
+            discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
             actions = torch.cat([actions, discrete_action], dim=-1)
 
         return actions
@@ -236,7 +234,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
 
@@ -275,18 +273,25 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
-            if self.config.num_discrete_actions is not None:
-                loss_grasp_critic = self.compute_loss_grasp_critic(
-                    observations=observations,
-                    actions=actions,
-                    rewards=rewards,
-                    next_observations=next_observations,
-                    done=done,
-                )
-                return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
 
             return {"loss_critic": loss_critic}
 
+        if model == "grasp_critic" and self.config.num_discrete_actions is not None:
+            # Extract critic-specific components
+            rewards: Tensor = batch["reward"]
+            next_observations: dict[str, Tensor] = batch["next_state"]
+            done: Tensor = batch["done"]
+            next_observation_features: Tensor = batch.get("next_observation_feature")
+            loss_grasp_critic = self.compute_loss_grasp_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+            )
+            return {"loss_grasp_critic": loss_grasp_critic}
         if model == "actor":
             return {
                 "loss_actor": self.compute_loss_actor(
@@ -373,7 +378,6 @@ class SACPolicy(
             # In the buffer we have the full action space (continuous + discrete)
             # We need to split them before concatenating them in the critic forward
             actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
-
         q_preds = self.critic_forward(
             observations=observations,
             actions=actions,
@@ -407,30 +411,38 @@ class SACPolicy(
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
         # We need to split them before concatenating them in the critic forward
-        actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
-        actions = actions.long()
+        actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
+        actions_discrete = actions_discrete.long()
 
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
-            next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
-            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
+            next_grasp_qs = self.grasp_critic_forward(
+                next_observations, use_target=False, observation_features=next_observation_features
+            )
+            best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
 
             # Get target Q-values from target network
-            target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
+            target_next_grasp_qs = self.grasp_critic_forward(
+                observations=next_observations,
+                use_target=True,
+                observation_features=next_observation_features,
+            )
 
             # Use gather to select Q-values for best actions
             target_next_grasp_q = torch.gather(
-                target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
+                target_next_grasp_qs, dim=1, index=best_next_grasp_action
             ).squeeze(-1)
 
         # Compute target Q-value with Bellman equation
         target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
-        predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
+        predicted_grasp_qs = self.grasp_critic_forward(
+            observations=observations, use_target=False, observation_features=observation_features
+        )
 
         # Use gather to select Q-values for taken actions
-        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
+        predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
 
         # Compute MSE loss between predicted and target Q-values
         grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
@@ -642,49 +654,52 @@ class GraspCritic(nn.Module):
     def __init__(
         self,
         encoder: Optional[nn.Module],
-        network: nn.Module,
-        output_dim: int = 3,  # TODO (azouitine): rename it to number of discrete actions or something like that
+        input_dim: int,
+        hidden_dims: list[int],
+        output_dim: int = 3,
+        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
+        activate_final: bool = False,
+        dropout_rate: Optional[float] = None,
         init_final: Optional[float] = None,
-        encoder_is_shared: bool = False,
+        final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.encoder = encoder
-        self.network = network
         self.output_dim = output_dim
 
-        # Find the last Linear layer's output dimension
-        for layer in reversed(network.net):
-            if isinstance(layer, nn.Linear):
-                out_features = layer.out_features
-                break
+        self.net = MLP(
+            input_dim=input_dim,
+            hidden_dims=hidden_dims,
+            activations=activations,
+            activate_final=activate_final,
+            dropout_rate=dropout_rate,
+            final_activation=final_activation,
+        )
 
-        self.parameters_to_optimize += list(self.network.parameters())
-
-        if self.encoder is not None and not encoder_is_shared:
-            self.parameters_to_optimize += list(self.encoder.parameters())
-
-        self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
+        self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
         if init_final is not None:
             nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
             nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
         else:
             orthogonal_init()(self.output_layer.weight)
 
+        self.parameters_to_optimize = []
+        self.parameters_to_optimize += list(self.net.parameters())
         self.parameters_to_optimize += list(self.output_layer.parameters())
 
     def forward(
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device
+        # Move each tensor in observations to device without modifying the caller's dict in place
         observations = {k: v.to(device) for k, v in observations.items()}
         # Encode observations if encoder exists
         obs_enc = (
-            observation_features
+            observation_features.to(device)
             if observation_features is not None
             else (observations if self.encoder is None else self.encoder(observations))
         )
-        return self.output_layer(self.network(obs_enc))
+        return self.output_layer(self.net(obs_enc))
 
 
 class Policy(nn.Module):
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index c57f83fc..ce9a1b41 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -405,12 +405,13 @@ def add_actor_information_and_train(
             optimizers["critic"].step()
 
             # Grasp critic optimization (if available)
-            if "loss_grasp_critic" in critic_output:
-                loss_grasp_critic = critic_output["loss_grasp_critic"]
+            if policy.config.num_discrete_actions is not None:
+                discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+                loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
                 optimizers["grasp_critic"].zero_grad()
                 loss_grasp_critic.backward()
                 grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                    parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                    parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
                 )
                 optimizers["grasp_critic"].step()
 
@@ -467,12 +468,13 @@ def add_actor_information_and_train(
         }
 
         # Grasp critic optimization (if available)
-        if "loss_grasp_critic" in critic_output:
-            loss_grasp_critic = critic_output["loss_grasp_critic"]
+        if policy.config.num_discrete_actions is not None:
+            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
+            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
             optimizers["grasp_critic"].zero_grad()
             loss_grasp_critic.backward()
             grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+                parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
             ).item()
             optimizers["grasp_critic"].step()
 
@@ -759,7 +761,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters(), lr=policy.critic_lr
+            params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py
index f4a89888..b5c181c1 100644
--- a/lerobot/scripts/server/maniskill_manipulator.py
+++ b/lerobot/scripts/server/maniskill_manipulator.py
@@ -16,7 +16,6 @@ from lerobot.common.policies.sac.configuration_sac import SACConfig
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
 
 
-
 def preprocess_maniskill_observation(
     observations: dict[str, np.ndarray],
 ) -> dict[str, torch.Tensor]:
@@ -156,6 +155,7 @@ class TimeLimitWrapper(gym.Wrapper):
         self.current_step = 0
         return super().reset(seed=seed, options=options)
 
+
 class ManiskillMockGripperWrapper(gym.Wrapper):
     def __init__(self, env, nb_discrete_actions: int = 3):
         super().__init__(env)
@@ -166,11 +166,16 @@ class ManiskillMockGripperWrapper(gym.Wrapper):
         self.action_space = gym.spaces.Tuple((action_space_agent, env.action_space[1]))
 
     def step(self, action):
-        action_agent, telop_action  = action  
+        if isinstance(action, tuple):
+            action_agent, telop_action = action
+        else:
+            telop_action = 0
+            action_agent = action
         real_action = action_agent[:-1]
         final_action = (real_action, telop_action)
         obs, reward, terminated, truncated, info = self.env.step(final_action)
-        return obs, reward, terminated, truncated, info 
+        return obs, reward, terminated, truncated, info
+
 
 def make_maniskill(
     cfg: ManiskillEnvConfig,

From 51f1625c2004cdb8adf50c5c862a1c778d746c97 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Thu, 3 Apr 2025 07:44:46 +0000
Subject: [PATCH 13/28] Enhance SACPolicy to support shared encoder and
 optimize action selection

- Cached encoder output in select_action method to reduce redundant computations (a toy sketch of the shared-encoder reuse follows the diffstat below).
- Updated action selection and grasp critic calls to utilize cached encoder features when available.
---
 lerobot/common/policies/sac/modeling_sac.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)
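
The caching idea is independent of SAC itself: encode the observation once and hand the cached features to every head that shares the encoder. A toy sketch with illustrative nn.Linear stand-ins (not the actual SACObservationEncoder or heads):

import torch
import torch.nn as nn

# Illustrative stand-ins: a single shared encoder feeding both heads.
encoder = nn.Linear(10, 32)
actor_head = nn.Linear(32, 6)    # continuous action head
grasp_head = nn.Linear(32, 3)    # one Q-value per discrete gripper action

obs = torch.randn(4, 10)

with torch.no_grad():
    features = encoder(obs)                   # encoded once...
    continuous_action = actor_head(features)  # ...reused by the actor
    gripper_q = grasp_head(features)          # ...and by the grasp critic
    gripper_action = gripper_q.argmax(dim=-1, keepdim=True).float()
    full_action = torch.cat([continuous_action, gripper_action], dim=-1)

print(full_action.shape)  # torch.Size([4, 7])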

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index af624592..2246bf8c 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -81,6 +81,7 @@ class SACPolicy(
         else:
             encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
             encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
+        self.shared_encoder = config.shared_encoder
 
         # Create a list of critic heads
         critic_heads = [
@@ -184,11 +185,16 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        actions, _, _ = self.actor(batch)
+        # We cached the encoder output to avoid recomputing it
+        observations_features = None
+        if self.shared_encoder and self.actor.encoder is not None:
+            observations_features = self.actor.encoder(batch)
+
+        actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
 
         if self.config.num_discrete_actions is not None:
-            discrete_action_value = self.grasp_critic(batch)
+            discrete_action_value = self.grasp_critic(batch, observations_features)
             discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
             actions = torch.cat([actions, discrete_action], dim=-1)
 

From 38a8dbd9c9a0fe5bf9fb0244be068869ea0bc1c1 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Thu, 3 Apr 2025 14:23:50 +0000
Subject: [PATCH 14/28] Enhance SAC configuration and replay buffer with
 asynchronous prefetching support

- Added async_prefetch parameter to SACConfig for improved buffer management.
- Implemented get_iterator method in ReplayBuffer to support asynchronous prefetching of batches (a usage sketch follows the diffstat below).
- Updated learner_server to utilize the new iterator for online and offline sampling, enhancing training efficiency.
---
 .../common/policies/sac/configuration_sac.py  |   4 +-
 lerobot/scripts/server/buffer.py              | 576 ++++--------------
 lerobot/scripts/server/learner_server.py      |  34 +-
 3 files changed, 132 insertions(+), 482 deletions(-)
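
A hedged usage sketch of the new iterator API; the constructor and add() arguments mirror the synthetic test this commit removes from buffer.py's __main__ block, and the module is assumed to be importable at this revision:

import torch
from lerobot.scripts.server.buffer import ReplayBuffer

# Small synthetic buffer, populated with random transitions.
buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["observation.state"], storage_device="cpu")
for _ in range(20):
    state = {"observation.state": torch.rand(1, 10)}
    next_state = {"observation.state": torch.rand(1, 10)}
    buffer.add(state=state, action=torch.rand(1, 4), reward=0.0,
               next_state=next_state, done=False, truncated=False)

# The iterator is infinite; async_prefetch=True samples in a background thread instead.
iterator = buffer.get_iterator(batch_size=8, async_prefetch=False, queue_size=2)
batch = next(iterator)
print(batch["action"].shape)  # batched actions, same format as buffer.sample(8)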

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index ae38b1c5..3d01f47c 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -42,8 +42,6 @@ class CriticNetworkConfig:
     final_activation: str | None = None
 
 
-
-
 @dataclass
 class ActorNetworkConfig:
     hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
@@ -94,6 +92,7 @@ class SACConfig(PreTrainedConfig):
         online_env_seed: Seed for the online environment.
         online_buffer_capacity: Capacity of the online replay buffer.
         offline_buffer_capacity: Capacity of the offline replay buffer.
+        async_prefetch: Whether to use asynchronous prefetching for the buffers.
         online_step_before_learning: Number of steps before learning starts.
         policy_update_freq: Frequency of policy updates.
         discount: Discount factor for the SAC algorithm.
@@ -154,6 +153,7 @@ class SACConfig(PreTrainedConfig):
     online_env_seed: int = 10000
     online_buffer_capacity: int = 100000
     offline_buffer_capacity: int = 100000
+    async_prefetch: bool = False
     online_step_before_learning: int = 100
     policy_update_freq: int = 1
 
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 2af3995e..c8f85372 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -345,6 +345,109 @@ class ReplayBuffer:
             truncated=batch_truncateds,
         )
 
+    def get_iterator(
+        self,
+        batch_size: int,
+        async_prefetch: bool = True,
+        queue_size: int = 2,
+    ):
+        """
+        Creates an infinite iterator that yields batches of transitions.
+        Will automatically restart when internal iterator is exhausted.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
+            queue_size (int): Number of batches to prefetch (default: 2)
+
+        Yields:
+            BatchTransition: Batched transitions
+        """
+        while True:  # Create an infinite loop
+            if async_prefetch:
+                # Get the asynchronous prefetching iterator
+                iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
+            else:
+                iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
+
+            # Yield all items from the iterator
+            try:
+                yield from iterator
+            except StopIteration:
+                # Just continue the outer loop to create a new iterator
+                pass
+
+    def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
+        """
+        Creates an iterator that prefetches batches in a background thread.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            queue_size (int): Number of batches to prefetch (default: 2)
+
+        Yields:
+            BatchTransition: Prefetched batch transitions
+        """
+        import threading
+        import queue
+
+        # Use thread-safe queue
+        data_queue = queue.Queue(maxsize=queue_size)
+        running = [True]  # Use list to allow modification in nested function
+
+        def prefetch_worker():
+            while running[0]:
+                try:
+                    # Sample data and add to queue
+                    data = self.sample(batch_size)
+                    data_queue.put(data, block=True, timeout=0.5)
+                except queue.Full:
+                    continue
+                except Exception as e:
+                    print(f"Prefetch error: {e}")
+                    break
+
+        # Start prefetching thread
+        thread = threading.Thread(target=prefetch_worker, daemon=True)
+        thread.start()
+
+        try:
+            while running[0]:
+                try:
+                    yield data_queue.get(block=True, timeout=0.5)
+                except queue.Empty:
+                    if not thread.is_alive():
+                        break
+        finally:
+            # Clean up
+            running[0] = False
+            thread.join(timeout=1.0)
+
+    def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
+        """
+        Creates a simple non-threaded iterator that yields batches.
+
+        Args:
+            batch_size (int): Size of batches to sample
+            queue_size (int): Number of initial batches to prefetch
+
+        Yields:
+            BatchTransition: Batch transitions
+        """
+        import collections
+
+        queue = collections.deque()
+
+        def enqueue(n):
+            for _ in range(n):
+                data = self.sample(batch_size)
+                queue.append(data)
+
+        enqueue(queue_size)
+        while queue:
+            yield queue.popleft()
+            enqueue(1)
+
     @classmethod
     def from_lerobot_dataset(
         cls,
@@ -710,475 +813,4 @@ def concatenate_batch_transitions(
 
 
 if __name__ == "__main__":
-    from tempfile import TemporaryDirectory
-
-    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
-    print("Testing synthetic ReplayBuffer...")
-
-    # Create sample data dimensions
-    batch_size = 32
-    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
-    action_dim = (6,)
-
-    # Create a buffer
-    buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=list(state_dims.keys()),
-        use_drq=True,
-        storage_device="cpu",
-    )
-
-    # Add some random transitions
-    for i in range(100):
-        # Create dummy transition data
-        state = {
-            "observation.image": torch.rand(1, 3, 84, 84),
-            "observation.state": torch.rand(1, 10),
-        }
-        action = torch.rand(1, 6)
-        reward = 0.5
-        next_state = {
-            "observation.image": torch.rand(1, 3, 84, 84),
-            "observation.state": torch.rand(1, 10),
-        }
-        done = False if i < 99 else True
-        truncated = False
-
-        buffer.add(
-            state=state,
-            action=action,
-            reward=reward,
-            next_state=next_state,
-            done=done,
-            truncated=truncated,
-        )
-
-    # Test sampling
-    batch = buffer.sample(batch_size)
-    print(f"Buffer size: {len(buffer)}")
-    print(
-        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
-    )
-    print(f"Sampled batch action shape: {batch['action'].shape}")
-    print(f"Sampled batch reward shape: {batch['reward'].shape}")
-    print(f"Sampled batch done shape: {batch['done'].shape}")
-    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
-
-    # ===== Test for state-action-reward alignment =====
-    print("\nTesting state-action-reward alignment...")
-
-    # Create a buffer with controlled transitions where we know the relationships
-    aligned_buffer = ReplayBuffer(
-        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
-    )
-
-    # Create transitions with known relationships
-    # - Each state has a unique signature value
-    # - Action is 2x the state signature
-    # - Reward is 3x the state signature
-    # - Next state is signature + 0.01 (unless at episode end)
-    for i in range(100):
-        # Create a state with a signature value that encodes the transition number
-        signature = float(i) / 100.0
-        state = {"state_value": torch.tensor([[signature]]).float()}
-
-        # Action is 2x the signature
-        action = torch.tensor([[2.0 * signature]]).float()
-
-        # Reward is 3x the signature
-        reward = 3.0 * signature
-
-        # Next state is signature + 0.01, unless end of episode
-        # End episode every 10 steps
-        is_end = (i + 1) % 10 == 0
-
-        if is_end:
-            # At episode boundaries, next_state repeats current state (as per your implementation)
-            next_state = {"state_value": torch.tensor([[signature]]).float()}
-            done = True
-        else:
-            # Within episodes, next_state has signature + 0.01
-            next_signature = float(i + 1) / 100.0
-            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
-            done = False
-
-        aligned_buffer.add(state, action, reward, next_state, done, False)
-
-    # Sample from this buffer
-    aligned_batch = aligned_buffer.sample(50)
-
-    # Verify alignments in sampled batch
-    correct_relationships = 0
-    total_checks = 0
-
-    # For each transition in the batch
-    for i in range(50):
-        # Extract signature from state
-        state_sig = aligned_batch["state"]["state_value"][i].item()
-
-        # Check action is 2x signature (within reasonable precision)
-        action_val = aligned_batch["action"][i].item()
-        action_check = abs(action_val - 2.0 * state_sig) < 1e-4
-
-        # Check reward is 3x signature (within reasonable precision)
-        reward_val = aligned_batch["reward"][i].item()
-        reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4
-
-        # Check next_state relationship matches our pattern
-        next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
-        is_done = aligned_batch["done"][i].item() > 0.5
-
-        # Calculate expected next_state value based on done flag
-        if is_done:
-            # For episodes that end, next_state should equal state
-            next_state_check = abs(next_state_sig - state_sig) < 1e-4
-        else:
-            # For continuing episodes, check if next_state is approximately state + 0.01
-            # We need to be careful because we don't know the original index
-            # So we check if the increment is roughly 0.01
-            next_state_check = (
-                abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
-            )
-
-        # Count correct relationships
-        if action_check:
-            correct_relationships += 1
-        if reward_check:
-            correct_relationships += 1
-        if next_state_check:
-            correct_relationships += 1
-
-        total_checks += 3
-
-    alignment_accuracy = 100.0 * correct_relationships / total_checks
-    print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
-    if alignment_accuracy > 99.0:
-        print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
-    else:
-        print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")
-
-        # Print some debug information about failures
-        print("\nDebug information for failed checks:")
-        for i in range(5):  # Print first 5 transitions for debugging
-            state_sig = aligned_batch["state"]["state_value"][i].item()
-            action_val = aligned_batch["action"][i].item()
-            reward_val = aligned_batch["reward"][i].item()
-            next_state_sig = aligned_batch["next_state"]["state_value"][i].item()
-            is_done = aligned_batch["done"][i].item() > 0.5
-
-            print(f"Transition {i}:")
-            print(f"  State: {state_sig:.6f}")
-            print(f"  Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})")
-            print(f"  Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})")
-            print(f"  Done: {is_done}")
-            print(f"  Next state: {next_state_sig:.6f}")
-
-            # Calculate expected next state
-            if is_done:
-                expected_next = state_sig
-            else:
-                # This approximation might not be perfect
-                state_idx = round(state_sig * 100)
-                expected_next = (state_idx + 1) / 100.0
-
-            print(f"  Expected next state: {expected_next:.6f}")
-            print()
-
-    # ===== Test 2: Convert to LeRobotDataset and back =====
-    with TemporaryDirectory() as temp_dir:
-        print("\nTesting conversion to LeRobotDataset and back...")
-        # Convert buffer to dataset
-        repo_id = "test/replay_buffer_conversion"
-        # Create a subdirectory to avoid the "directory exists" error
-        dataset_dir = os.path.join(temp_dir, "dataset1")
-        dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir)
-
-        print(f"Dataset created with {len(dataset)} frames")
-        print(f"Dataset features: {list(dataset.features.keys())}")
-
-        # Check a random sample from the dataset
-        sample = dataset[0]
-        print(
-            f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}"
-        )
-
-        # Convert dataset back to buffer
-        reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
-            dataset, state_keys=list(state_dims.keys()), device="cpu"
-        )
-
-        print(f"Reconverted buffer size: {len(reconverted_buffer)}")
-
-        # Sample from the reconverted buffer
-        reconverted_batch = reconverted_buffer.sample(batch_size)
-        print(
-            f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}"
-        )
-
-        # Verify consistency before and after conversion
-        original_states = batch["state"]["observation.image"].mean().item()
-        reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
-        print(f"Original buffer state mean: {original_states:.4f}")
-        print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
-
-        if abs(original_states - reconverted_states) < 1.0:
-            print("Values are reasonably similar - conversion works as expected")
-        else:
-            print("WARNING: Significant difference between original and reconverted values")
-
-    print("\nAll previous tests completed!")
-
-    # ===== Test for memory optimization =====
-    print("\n===== Testing Memory Optimization =====")
-
-    # Create two buffers, one with memory optimization and one without
-    standard_buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=["observation.image", "observation.state"],
-        storage_device="cpu",
-        optimize_memory=False,
-        use_drq=True,
-    )
-
-    optimized_buffer = ReplayBuffer(
-        capacity=1000,
-        device="cpu",
-        state_keys=["observation.image", "observation.state"],
-        storage_device="cpu",
-        optimize_memory=True,
-        use_drq=True,
-    )
-
-    # Generate sample data with larger state dimensions for better memory impact
-    print("Generating test data...")
-    num_episodes = 10
-    steps_per_episode = 50
-    total_steps = num_episodes * steps_per_episode
-
-    for episode in range(num_episodes):
-        for step in range(steps_per_episode):
-            # Index in the overall sequence
-            i = episode * steps_per_episode + step
-
-            # Create state with identifiable values
-            img = torch.ones((3, 84, 84)) * (i / total_steps)
-            state_vec = torch.ones((10,)) * (i / total_steps)
-
-            state = {
-                "observation.image": img.unsqueeze(0),
-                "observation.state": state_vec.unsqueeze(0),
-            }
-
-            # Create next state (i+1 or same as current if last in episode)
-            is_last_step = step == steps_per_episode - 1
-
-            if is_last_step:
-                # At episode end, next state = current state
-                next_img = img.clone()
-                next_state_vec = state_vec.clone()
-                done = True
-                truncated = False
-            else:
-                # Within episode, next state has incremented value
-                next_val = (i + 1) / total_steps
-                next_img = torch.ones((3, 84, 84)) * next_val
-                next_state_vec = torch.ones((10,)) * next_val
-                done = False
-                truncated = False
-
-            next_state = {
-                "observation.image": next_img.unsqueeze(0),
-                "observation.state": next_state_vec.unsqueeze(0),
-            }
-
-            # Action and reward
-            action = torch.tensor([[i / total_steps]])
-            reward = float(i / total_steps)
-
-            # Add to both buffers
-            standard_buffer.add(state, action, reward, next_state, done, truncated)
-            optimized_buffer.add(state, action, reward, next_state, done, truncated)
-
-    # Verify episode boundaries with our simplified approach
-    print("\nVerifying simplified memory optimization...")
-
-    # Test with a new buffer with a small sequence
-    test_buffer = ReplayBuffer(
-        capacity=20,
-        device="cpu",
-        state_keys=["value"],
-        storage_device="cpu",
-        optimize_memory=True,
-        use_drq=False,
-    )
-
-    # Add a simple sequence with known episode boundaries
-    for i in range(20):
-        val = float(i)
-        state = {"value": torch.tensor([[val]]).float()}
-        next_val = float(i + 1) if i % 5 != 4 else val  # Episode ends every 5 steps
-        next_state = {"value": torch.tensor([[next_val]]).float()}
-
-        # Set done=True at every 5th step
-        done = (i % 5) == 4
-        action = torch.tensor([[0.0]])
-        reward = 1.0
-        truncated = False
-
-        test_buffer.add(state, action, reward, next_state, done, truncated)
-
-    # Get sequential batch for verification
-    sequential_batch_size = test_buffer.size
-    all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
-
-    # Get state tensors
-    batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}
-
-    # Get next_state using memory-optimized approach (simply index+1)
-    next_indices = (all_indices + 1) % test_buffer.capacity
-    batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}
-
-    # Get other tensors
-    batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
-
-    # Print sequential values
-    print("State, Next State, Done (Sequential values with simplified optimization):")
-    state_values = batch_state["value"].squeeze().tolist()
-    next_values = batch_next_state["value"].squeeze().tolist()
-    done_flags = batch_dones.tolist()
-
-    # Print all values
-    for i in range(len(state_values)):
-        print(f"  {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}")
-
-    # Explain the memory optimization tradeoff
-    print("\nWith simplified memory optimization:")
-    print("- We always use the next state in the buffer (index+1) as next_state")
-    print("- For terminal states, this means using the first state of the next episode")
-    print("- This is a common tradeoff in RL implementations for memory efficiency")
-    print("- Since we track done flags, the algorithm can handle these transitions correctly")
-
-    # Test random sampling
-    print("\nVerifying random sampling with simplified memory optimization...")
-    random_samples = test_buffer.sample(20)  # Sample all transitions
-
-    # Extract values
-    random_state_values = random_samples["state"]["value"].squeeze().tolist()
-    random_next_values = random_samples["next_state"]["value"].squeeze().tolist()
-    random_done_flags = random_samples["done"].bool().tolist()
-
-    # Print a few samples
-    print("Random samples - State, Next State, Done (First 10):")
-    for i in range(10):
-        print(f"  {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}")
-
-    # Calculate memory savings
-    # Assume optimized_buffer and standard_buffer have already been initialized and filled
-    std_mem = (
-        sum(
-            standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
-            for key in standard_buffer.states
-        )
-        * 2
-    )
-    opt_mem = sum(
-        optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
-        for key in optimized_buffer.states
-    )
-
-    savings_percent = (std_mem - opt_mem) / std_mem * 100
-
-    print("\nMemory optimization result:")
-    print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
-    print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
-    print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
-
-    print("\nAll memory optimization tests completed!")
-
-    # # ===== Test real dataset conversion =====
-    # print("\n===== Testing Real LeRobotDataset Conversion =====")
-    # try:
-    #     # Try to use a real dataset if available
-    #     dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small"
-    #     dataset = LeRobotDataset(repo_id=dataset_name)
-
-    #     # Print available keys to debug
-    #     sample = dataset[0]
-    #     print("Available keys in dataset:", list(sample.keys()))
-
-    #     # Check for required keys
-    #     if "action" not in sample or "next.reward" not in sample:
-    #         print("Dataset missing essential keys. Cannot convert.")
-    #         raise ValueError("Missing required keys in dataset")
-
-    #     # Auto-detect appropriate state keys
-    #     image_keys = []
-    #     state_keys = []
-    #     for k, v in sample.items():
-    #         # Skip metadata keys and action/reward keys
-    #         if k in {
-    #             "index",
-    #             "episode_index",
-    #             "frame_index",
-    #             "timestamp",
-    #             "task_index",
-    #             "action",
-    #             "next.reward",
-    #             "next.done",
-    #         }:
-    #             continue
-
-    #         # Infer key type from tensor shape
-    #         if isinstance(v, torch.Tensor):
-    #             if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1):
-    #                 # Likely an image (channels, height, width)
-    #                 image_keys.append(k)
-    #             else:
-    #                 # Likely state or other vector
-    #                 state_keys.append(k)
-
-    #     print(f"Detected image keys: {image_keys}")
-    #     print(f"Detected state keys: {state_keys}")
-
-    #     if not image_keys and not state_keys:
-    #         print("No usable keys found in dataset, skipping further tests")
-    #         raise ValueError("No usable keys found in dataset")
-
-    #     # Test with standard and memory-optimized buffers
-    #     for optimize_memory in [False, True]:
-    #         buffer_type = "Standard" if not optimize_memory else "Memory-optimized"
-    #         print(f"\nTesting {buffer_type} buffer with real dataset...")
-
-    #         # Convert to ReplayBuffer with detected keys
-    #         replay_buffer = ReplayBuffer.from_lerobot_dataset(
-    #             lerobot_dataset=dataset,
-    #             state_keys=image_keys + state_keys,
-    #             device="cpu",
-    #             optimize_memory=optimize_memory,
-    #         )
-    #         print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}")
-
-    #         # Test sampling
-    #         real_batch = replay_buffer.sample(32)
-    #         print(f"Sampled batch from real dataset ({buffer_type}), state shapes:")
-    #         for key in real_batch["state"]:
-    #             print(f"  {key}: {real_batch['state'][key].shape}")
-
-    #         # Convert back to LeRobotDataset
-    #         with TemporaryDirectory() as temp_dir:
-    #             dataset_name = f"test/real_dataset_converted_{buffer_type}"
-    #             replay_buffer_converted = replay_buffer.to_lerobot_dataset(
-    #                 repo_id=dataset_name,
-    #                 root=os.path.join(temp_dir, f"dataset_{buffer_type}"),
-    #             )
-    #             print(
-    #                 f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames"
-    #             )
-
-    # except Exception as e:
-    #     print(f"Real dataset test failed: {e}")
-    #     print("This is expected if running offline or if the dataset is not available.")
-
-    # print("\nAll tests completed!")
+    pass  # All test code has been removed from this module
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index ce9a1b41..08baa6ea 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -269,6 +269,7 @@ def add_actor_information_and_train(
     policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
     saving_checkpoint = cfg.save_checkpoint
     online_steps = cfg.policy.online_steps
+    async_prefetch = cfg.policy.async_prefetch
 
     # Initialize logging for multiprocessing
     if not use_threads(cfg):
@@ -326,6 +327,9 @@ def add_actor_information_and_train(
     if cfg.dataset is not None:
         dataset_repo_id = cfg.dataset.repo_id
 
+    # Initialize iterators
+    online_iterator = None
+    offline_iterator = None
     # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
     while True:
         # Exit the training loop if shutdown is requested
@@ -359,16 +363,29 @@ def add_actor_information_and_train(
         if len(replay_buffer) < online_step_before_learning:
             continue
 
+        if online_iterator is None:
+            logging.debug("[LEARNER] Initializing online replay buffer iterator")
+            online_iterator = replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
+        if offline_replay_buffer is not None and offline_iterator is None:
+            logging.debug("[LEARNER] Initializing offline replay buffer iterator")
+            offline_iterator = offline_replay_buffer.get_iterator(
+                batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
+            )
+
         logging.debug("[LEARNER] Starting optimization loop")
         time_for_one_optimization_step = time.time()
         for _ in range(utd_ratio - 1):
-            batch = replay_buffer.sample(batch_size=batch_size)
+            # Sample from the iterators
+            batch = next(online_iterator)
 
-            if dataset_repo_id is not None:
-                batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
-                batch = concatenate_batch_transitions(
-                    left_batch_transitions=batch, right_batch_transition=batch_offline
-                )
+        if dataset_repo_id is not None:
+            batch_offline = next(offline_iterator)
+            batch = concatenate_batch_transitions(
+                left_batch_transitions=batch, right_batch_transition=batch_offline
+            )
 
             actions = batch["action"]
             rewards = batch["reward"]
@@ -418,10 +435,11 @@ def add_actor_information_and_train(
             # Update target networks
             policy.update_target_networks()
 
-        batch = replay_buffer.sample(batch_size=batch_size)
+            # Sample for the last update in the UTD ratio
+            batch = next(online_iterator)
 
         if dataset_repo_id is not None:
-            batch_offline = offline_replay_buffer.sample(batch_size=batch_size)
+            batch_offline = next(offline_iterator)
             batch = concatenate_batch_transitions(
                 left_batch_transitions=batch, right_batch_transition=batch_offline
             )

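The iterator added to ReplayBuffer above follows a plain producer/consumer pattern: a daemon thread keeps a bounded queue topped up with sampled batches while the training loop consumes them, so sampling overlaps with the optimization step. A minimal self-contained sketch of the same pattern is shown below; toy_sample and prefetch_iterator are illustrative stand-ins, not part of the patch.

import queue
import threading

import torch


def toy_sample(batch_size: int) -> torch.Tensor:
    # Stand-in for ReplayBuffer.sample(batch_size): returns a random batch.
    return torch.rand(batch_size, 10)


def prefetch_iterator(batch_size: int, queue_size: int = 2):
    """Yield batches produced by a background thread, mirroring _get_async_iterator."""
    data_queue: queue.Queue = queue.Queue(maxsize=queue_size)
    running = [True]  # mutable flag shared with the worker thread

    def worker():
        while running[0]:
            try:
                data_queue.put(toy_sample(batch_size), block=True, timeout=0.5)
            except queue.Full:
                continue

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    try:
        while running[0]:
            try:
                yield data_queue.get(block=True, timeout=0.5)
            except queue.Empty:
                if not thread.is_alive():
                    break
    finally:
        running[0] = False
        thread.join(timeout=1.0)


if __name__ == "__main__":
    batches = prefetch_iterator(batch_size=32)
    for _ in range(3):
        print(next(batches).shape)  # torch.Size([32, 10])
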
From e86fe66dbd392fbf087e0ac74f8e23c1262c272b Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Thu, 3 Apr 2025 16:05:29 +0000
Subject: [PATCH 15/28] fix indentation issue

---
 lerobot/scripts/server/learner_server.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 08baa6ea..65b1d9b8 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -381,11 +381,11 @@ def add_actor_information_and_train(
             # Sample from the iterators
             batch = next(online_iterator)
 
-        if dataset_repo_id is not None:
-            batch_offline = next(offline_iterator)
-            batch = concatenate_batch_transitions(
-                left_batch_transitions=batch, right_batch_transition=batch_offline
-            )
+            if dataset_repo_id is not None:
+                batch_offline = next(offline_iterator)
+                batch = concatenate_batch_transitions(
+                    left_batch_transitions=batch, right_batch_transition=batch_offline
+                )
 
             actions = batch["action"]
             rewards = batch["reward"]
@@ -435,8 +435,8 @@ def add_actor_information_and_train(
             # Update target networks
             policy.update_target_networks()
 
-            # Sample for the last update in the UTD ratio
-            batch = next(online_iterator)
+        # Sample for the last update in the UTD ratio
+        batch = next(online_iterator)
 
         if dataset_repo_id is not None:
             batch_offline = next(offline_iterator)

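The re-indentation above matters because the offline batch has to be merged inside every iteration of the update-to-data (UTD) loop, while the final sample of the cycle happens once, after that loop. A hedged sketch of the intended control flow follows; run_utd_updates, concatenate, and train_step are illustrative placeholders rather than the actual learner code.

from typing import Callable, Iterator, Optional


def run_utd_updates(
    online_iterator: Iterator,
    offline_iterator: Optional[Iterator],
    utd_ratio: int,
    concatenate: Callable,  # e.g. concatenate_batch_transitions from buffer.py
    train_step: Callable,   # one optimization step on a batch
) -> None:
    """Illustrative shape of the loop fixed in this patch, not the learner itself."""
    # The first utd_ratio - 1 updates each draw a fresh online batch and,
    # when an offline buffer exists, merge its batch inside the loop body.
    for _ in range(utd_ratio - 1):
        batch = next(online_iterator)
        if offline_iterator is not None:
            batch = concatenate(batch, next(offline_iterator))
        train_step(batch, critics_only=True)

    # The last update of the cycle samples once, outside the inner loop,
    # and runs the full actor/critic/temperature update.
    batch = next(online_iterator)
    if offline_iterator is not None:
        batch = concatenate(batch, next(offline_iterator))
    train_step(batch, critics_only=False)
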
From 037ecae9e0b89d07dbc3da30e177cd5d7c27ed70 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 4 Apr 2025 07:59:22 +0000
Subject: [PATCH 16/28] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 lerobot/scripts/server/buffer.py                | 3 +--
 lerobot/scripts/server/maniskill_manipulator.py | 6 ------
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index c8f85372..8947f6d9 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 import functools
 import io
-import os
 import pickle
 from typing import Any, Callable, Optional, Sequence, TypedDict
 
@@ -388,8 +387,8 @@ class ReplayBuffer:
         Yields:
             BatchTransition: Prefetched batch transitions
         """
-        import threading
         import queue
+        import threading
 
         # Use thread-safe queue
         data_queue = queue.Queue(maxsize=queue_size)
diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py
index b5c181c1..03a7ec10 100644
--- a/lerobot/scripts/server/maniskill_manipulator.py
+++ b/lerobot/scripts/server/maniskill_manipulator.py
@@ -1,5 +1,3 @@
-import logging
-import time
 from typing import Any
 
 import einops
@@ -10,10 +8,6 @@ from mani_skill.utils.wrappers.record import RecordEpisode
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
 
 from lerobot.common.envs.configs import ManiskillEnvConfig
-from lerobot.configs import parser
-from lerobot.configs.train import TrainPipelineConfig
-from lerobot.common.policies.sac.configuration_sac import SACConfig
-from lerobot.common.policies.sac.modeling_sac import SACPolicy
 
 
 def preprocess_maniskill_observation(

From 7741526ce42f2017db33841895d36e2c95ac875c Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Fri, 4 Apr 2025 14:29:38 +0000
Subject: [PATCH 17/28] fix caching

---
 lerobot/common/policies/sac/modeling_sac.py | 224 ++++++++++----------
 lerobot/scripts/server/learner_server.py    |  12 +-
 2 files changed, 117 insertions(+), 119 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 2246bf8c..b5bfb36e 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -187,8 +187,8 @@ class SACPolicy(
         """Select action for inference/evaluation"""
         # We cached the encoder output to avoid recomputing it
         observations_features = None
-        if self.shared_encoder and self.actor.encoder is not None:
-            observations_features = self.actor.encoder(batch)
+        if self.shared_encoder:
+            observations_features = self.actor.encoder.get_image_features(batch)
 
         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -484,6 +484,109 @@ class SACPolicy(
         return actor_loss
 
 
+class SACObservationEncoder(nn.Module):
+    """Encode image and/or state vector observations."""
+
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
+        """
+        Creates encoders for pixel and/or state modalities.
+        """
+        super().__init__()
+        self.config = config
+        self.input_normalization = input_normalizer
+        self.has_pretrained_vision_encoder = False
+        self.parameters_to_optimize = []
+
+        self.aggregation_size: int = 0
+        if any("observation.image" in key for key in config.input_features):
+            self.camera_number = config.camera_number
+
+            if self.config.vision_encoder_name is not None:
+                self.image_enc_layers = PretrainedImageEncoder(config)
+                self.has_pretrained_vision_encoder = True
+            else:
+                self.image_enc_layers = DefaultImageEncoder(config)
+
+            self.aggregation_size += config.latent_dim * self.camera_number
+
+            if config.freeze_vision_encoder:
+                freeze_image_encoder(self.image_enc_layers)
+            else:
+                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+            self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
+
+        if "observation.state" in config.input_features:
+            self.state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_features["observation.state"].shape[0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+
+            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
+
+        if "observation.environment_state" in config.input_features:
+            self.env_state_enc_layers = nn.Sequential(
+                nn.Linear(
+                    in_features=config.input_features["observation.environment_state"].shape[0],
+                    out_features=config.latent_dim,
+                ),
+                nn.LayerNorm(normalized_shape=config.latent_dim),
+                nn.Tanh(),
+            )
+            self.aggregation_size += config.latent_dim
+            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
+
+        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
+        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
+
+    def forward(
+        self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
+    ) -> Tensor:
+        """Encode the image and/or state vector.
+
+        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
+        over all features.
+        """
+        feat = []
+        obs_dict = self.input_normalization(obs_dict)
+        if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
+            vision_encoder_cache = self.get_image_features(obs_dict)
+
+        if vision_encoder_cache is not None:
+            feat.append(vision_encoder_cache)
+
+        if "observation.environment_state" in self.config.input_features:
+            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
+        if "observation.state" in self.config.input_features:
+            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
+
+        features = torch.cat(tensors=feat, dim=-1)
+        features = self.aggregation_layer(features)
+
+        return features
+
+    def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
+        # [N*B, C, H, W]
+        if len(self.all_image_keys) > 0:
+            # Batch all images along the batch dimension, then encode them.
+            images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
+            images_batched = self.image_enc_layers(images_batched)
+            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
+            embeddings_image = torch.cat(embeddings_chunks, dim=-1)
+            return embeddings_image
+        return None
+
+    @property
+    def output_dim(self) -> int:
+        """Returns the dimension of the encoder output"""
+        return self.config.latent_dim
+
+
 class MLP(nn.Module):
     def __init__(
         self,
@@ -606,7 +709,7 @@ class CriticEnsemble(nn.Module):
 
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: SACObservationEncoder,
         ensemble: List[CriticHead],
         output_normalization: nn.Module,
         init_final: Optional[float] = None,
@@ -638,11 +741,7 @@ class CriticEnsemble(nn.Module):
         actions = self.output_normalization(actions)["action"]
         actions = actions.to(device)
 
-        obs_enc = (
-            observation_features
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, observation_features)
 
         inputs = torch.cat([obs_enc, actions], dim=-1)
 
@@ -659,7 +758,7 @@ class CriticEnsemble(nn.Module):
 class GraspCritic(nn.Module):
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: nn.Module,
         input_dim: int,
         hidden_dims: list[int],
         output_dim: int = 3,
@@ -699,19 +798,14 @@ class GraspCritic(nn.Module):
         device = get_device_from_parameters(self)
         # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
-        # Encode observations if encoder exists
-        obs_enc = (
-            observation_features.to(device)
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
         return self.output_layer(self.net(obs_enc))
 
 
 class Policy(nn.Module):
     def __init__(
         self,
-        encoder: Optional[nn.Module],
+        encoder: SACObservationEncoder,
         network: nn.Module,
         action_dim: int,
         log_std_min: float = -5,
@@ -722,7 +816,7 @@ class Policy(nn.Module):
         encoder_is_shared: bool = False,
     ):
         super().__init__()
-        self.encoder = encoder
+        self.encoder: SACObservationEncoder = encoder
         self.network = network
         self.action_dim = action_dim
         self.log_std_min = log_std_min
@@ -765,11 +859,7 @@ class Policy(nn.Module):
         observation_features: torch.Tensor | None = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
-        obs_enc = (
-            observation_features
-            if observation_features is not None
-            else (observations if self.encoder is None else self.encoder(observations))
-        )
+        obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
 
         # Get network outputs
         outputs = self.network(obs_enc)
@@ -813,96 +903,6 @@ class Policy(nn.Module):
         return observations
 
 
-class SACObservationEncoder(nn.Module):
-    """Encode image and/or state vector observations."""
-
-    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
-        """
-        Creates encoders for pixel and/or state modalities.
-        """
-        super().__init__()
-        self.config = config
-        self.input_normalization = input_normalizer
-        self.has_pretrained_vision_encoder = False
-        self.parameters_to_optimize = []
-
-        self.aggregation_size: int = 0
-        if any("observation.image" in key for key in config.input_features):
-            self.camera_number = config.camera_number
-
-            if self.config.vision_encoder_name is not None:
-                self.image_enc_layers = PretrainedImageEncoder(config)
-                self.has_pretrained_vision_encoder = True
-            else:
-                self.image_enc_layers = DefaultImageEncoder(config)
-
-            self.aggregation_size += config.latent_dim * self.camera_number
-
-            if config.freeze_vision_encoder:
-                freeze_image_encoder(self.image_enc_layers)
-            else:
-                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-            self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
-
-        if "observation.state" in config.input_features:
-            self.state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-
-            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
-
-        if "observation.environment_state" in config.input_features:
-            self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    in_features=config.input_features["observation.environment_state"].shape[0],
-                    out_features=config.latent_dim,
-                ),
-                nn.LayerNorm(normalized_shape=config.latent_dim),
-                nn.Tanh(),
-            )
-            self.aggregation_size += config.latent_dim
-            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
-
-        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
-
-    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
-        """Encode the image and/or state vector.
-
-        Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
-        over all features.
-        """
-        feat = []
-        obs_dict = self.input_normalization(obs_dict)
-        # Batch all images along the batch dimension, then encode them.
-        if len(self.all_image_keys) > 0:
-            images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
-            images_batched = self.image_enc_layers(images_batched)
-            embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
-            feat.extend(embeddings_chunks)
-
-        if "observation.environment_state" in self.config.input_features:
-            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
-        if "observation.state" in self.config.input_features:
-            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
-
-        features = torch.cat(tensors=feat, dim=-1)
-        features = self.aggregation_layer(features)
-
-        return features
-
-    @property
-    def output_dim(self) -> int:
-        """Returns the dimension of the encoder output"""
-        return self.config.latent_dim
-
-
 class DefaultImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 65b1d9b8..37586fe9 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -775,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         params=policy.actor.parameters_to_optimize,
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
+    )
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
@@ -1024,12 +1026,8 @@ def get_observation_features(
         return None, None
 
     with torch.no_grad():
-        observation_features = (
-            policy.actor.encoder(observations) if policy.actor.encoder is not None else None
-        )
-        next_observation_features = (
-            policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
-        )
+        observation_features = policy.actor.encoder.get_image_features(observations)
+        next_observation_features = policy.actor.encoder.get_image_features(next_observations)
 
     return observation_features, next_observation_features
 

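The caching fix above hinges on one contract: only the vision backbone's output is cached (get_image_features), and the encoder's forward accepts that cache so actor and critics never re-run the image encoder on the same batch, while state features are always recomputed. A small sketch of that contract, assuming a toy encoder with the same get_image_features / forward(obs, vision_encoder_cache=...) shape; TinyObsEncoder and its dimensions are invented for illustration.

from __future__ import annotations

import torch
from torch import Tensor, nn


class TinyObsEncoder(nn.Module):
    """Toy stand-in for SACObservationEncoder: image features can be cached, state features cannot."""

    def __init__(self, img_dim: int = 16, state_dim: int = 4, latent_dim: int = 8):
        super().__init__()
        self.image_enc = nn.Linear(img_dim, latent_dim)
        self.state_enc = nn.Linear(state_dim, latent_dim)
        self.aggregate = nn.Linear(2 * latent_dim, latent_dim)

    def get_image_features(self, obs: dict[str, Tensor]) -> Tensor:
        # Expensive part: run the vision backbone once and reuse the result.
        return self.image_enc(obs["observation.image"])

    def forward(self, obs: dict[str, Tensor], vision_encoder_cache: Tensor | None = None) -> Tensor:
        if vision_encoder_cache is None:
            vision_encoder_cache = self.get_image_features(obs)
        state_feat = self.state_enc(obs["observation.state"])
        return self.aggregate(torch.cat([vision_encoder_cache, state_feat], dim=-1))


if __name__ == "__main__":
    enc = TinyObsEncoder()
    obs = {"observation.image": torch.rand(2, 16), "observation.state": torch.rand(2, 4)}
    cache = enc.get_image_features(obs)                  # computed once per batch
    actor_feat = enc(obs, vision_encoder_cache=cache)    # actor reuses the cache
    critic_feat = enc(obs, vision_encoder_cache=cache)   # critics reuse the same cache
    print(actor_feat.shape, critic_feat.shape)
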
From 4621f4e4f37faf8ce94b467132a23da7b4a5c839 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Mon, 7 Apr 2025 08:23:49 +0000
Subject: [PATCH 18/28] Handle gripper penalty

---
 lerobot/common/policies/sac/modeling_sac.py |  11 +-
 lerobot/scripts/server/buffer.py            | 167 ++++++++++++++++----
 lerobot/scripts/server/learner_server.py    |   2 +-
 3 files changed, 147 insertions(+), 33 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index b5bfb36e..e3d3765e 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -288,6 +288,7 @@ class SACPolicy(
             next_observations: dict[str, Tensor] = batch["next_state"]
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
+            complementary_info = batch.get("complementary_info")
             loss_grasp_critic = self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
@@ -296,6 +297,7 @@ class SACPolicy(
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
             )
             return {"loss_grasp_critic": loss_grasp_critic}
         if model == "actor":
@@ -413,6 +415,7 @@ class SACPolicy(
         done,
         observation_features=None,
         next_observation_features=None,
+        complementary_info=None,
     ):
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
@@ -420,6 +423,9 @@ class SACPolicy(
         actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
         actions_discrete = actions_discrete.long()
 
+        gripper_penalties: Tensor | None = None
+        if complementary_info is not None:
+            gripper_penalties = complementary_info.get("gripper_penalty")
+
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(
@@ -440,7 +446,10 @@ class SACPolicy(
             ).squeeze(-1)
 
         # Compute target Q-value with Bellman equation
-        target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
+        rewards_gripper = rewards
+        if gripper_penalties is not None:
+            rewards_gripper = rewards - gripper_penalties
+        target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 8947f6d9..92ad7dc7 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -32,7 +32,7 @@ class Transition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: bool
     truncated: bool
-    complementary_info: dict[str, Any] = None
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None
 
 
 class BatchTransition(TypedDict):
@@ -42,41 +42,43 @@ class BatchTransition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: torch.Tensor
     truncated: torch.Tensor
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None
 
 
 def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
-    # Move state tensors to CPU
     device = torch.device(device)
+    non_blocking = device.type == "cuda"
+
+    # Move state tensors to device
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
     }
 
-    # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
+    # Move action to device
+    transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
 
-    # No need to move reward or done, as they are float and bool
-
-    # No need to move reward or done, as they are float and bool
+    # Move reward and done if they are tensors
     if isinstance(transition["reward"], torch.Tensor):
-        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
+        transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
 
     if isinstance(transition["done"], torch.Tensor):
-        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
+        transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
 
     if isinstance(transition["truncated"], torch.Tensor):
-        transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
+        transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
 
-    # Move next_state tensors to CPU
+    # Move next_state tensors to device
     transition["next_state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda")
-        for key, val in transition["next_state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
     }
 
-    # If complementary_info is present, move its tensors to CPU
-    # if transition["complementary_info"] is not None:
-    #     transition["complementary_info"] = {
-    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
-    #     }
+    # Move complementary_info tensors if present
+    if transition.get("complementary_info") is not None:
+        transition["complementary_info"] = {
+            key: val.to(device, non_blocking=non_blocking)
+            for key, val in transition["complementary_info"].items()
+        }
+
     return transition
 
 
@@ -216,7 +218,12 @@ class ReplayBuffer:
             self.image_augmentation_function = torch.compile(base_function)
         self.use_drq = use_drq
 
-    def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
+    def _initialize_storage(
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        complementary_info: Optional[dict[str, torch.Tensor]] = None,
+    ):
         """Initialize the storage tensors based on the first transition."""
         # Determine shapes from the first transition
         state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
@@ -244,6 +251,26 @@ class ReplayBuffer:
         self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
 
+        # Initialize storage for complementary_info
+        self.has_complementary_info = complementary_info is not None
+        self.complementary_info_keys = []
+        self.complementary_info = {}
+
+        if self.has_complementary_info:
+            self.complementary_info_keys = list(complementary_info.keys())
+            # Pre-allocate tensors for each key in complementary_info
+            for key, value in complementary_info.items():
+                if isinstance(value, torch.Tensor):
+                    value_shape = value.squeeze(0).shape
+                    self.complementary_info[key] = torch.empty(
+                        (self.capacity, *value_shape), device=self.storage_device
+                    )
+                elif isinstance(value, (int, float)):
+                    # Handle scalar values similar to reward
+                    self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
+                else:
+                    raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
+
         self.initialized = True
 
     def __len__(self):
@@ -262,7 +289,7 @@ class ReplayBuffer:
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Initialize storage if this is the first transition
         if not self.initialized:
-            self._initialize_storage(state=state, action=action)
+            self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
 
         # Store the transition in pre-allocated tensors
         for key in self.states:
@@ -277,6 +304,17 @@ class ReplayBuffer:
         self.dones[self.position] = done
         self.truncateds[self.position] = truncated
 
+        # Handle complementary_info if provided and storage is initialized
+        if complementary_info is not None and self.has_complementary_info:
+            # Store the complementary_info
+            for key in self.complementary_info_keys:
+                if key in complementary_info:
+                    value = complementary_info[key]
+                    if isinstance(value, torch.Tensor):
+                        self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
+                    elif isinstance(value, (int, float)):
+                        self.complementary_info[key][self.position] = value
+
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
 
@@ -335,6 +373,13 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()
 
+        # Sample complementary_info if available
+        batch_complementary_info = None
+        if self.has_complementary_info:
+            batch_complementary_info = {}
+            for key in self.complementary_info_keys:
+                batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
+
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
@@ -342,6 +387,7 @@ class ReplayBuffer:
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
+            complementary_info=batch_complementary_info,
         )
 
     def get_iterator(
@@ -518,7 +564,19 @@ class ReplayBuffer:
             if action_delta is not None:
                 first_action = first_action / action_delta
 
-            replay_buffer._initialize_storage(state=first_state, action=first_action)
+            # Get complementary info if available
+            first_complementary_info = None
+            if (
+                "complementary_info" in first_transition
+                and first_transition["complementary_info"] is not None
+            ):
+                first_complementary_info = {
+                    k: v.to(device) for k, v in first_transition["complementary_info"].items()
+                }
+
+            replay_buffer._initialize_storage(
+                state=first_state, action=first_action, complementary_info=first_complementary_info
+            )
 
         # Fill the buffer with all transitions
         for data in list_transition:
@@ -546,6 +604,7 @@ class ReplayBuffer:
                 next_state=data["next_state"],
                 done=data["done"],
                 truncated=False,  # NOTE: Truncation is not supported yet in lerobot dataset
+                complementary_info=data.get("complementary_info", None),
             )
 
         return replay_buffer
@@ -587,6 +646,13 @@ class ReplayBuffer:
             f_info = guess_feature_info(t=sample_val, name=key)
             features[key] = f_info
 
+        # Add complementary_info keys if available
+        if self.has_complementary_info:
+            for key in self.complementary_info_keys:
+                sample_val = self.complementary_info[key][0]
+                f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
+                features[f"complementary_info.{key}"] = f_info
+
         # Create an empty LeRobotDataset
         lerobot_dataset = LeRobotDataset.create(
             repo_id=repo_id,
@@ -620,6 +686,11 @@ class ReplayBuffer:
             frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
             frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
 
+            # Add complementary_info if available
+            if self.has_complementary_info:
+                for key in self.complementary_info_keys:
+                    frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu()
+
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
 
@@ -686,6 +757,10 @@ class ReplayBuffer:
         sample = dataset[0]
         has_done_key = "next.done" in sample
 
+        # Check for complementary_info keys
+        complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")]
+        has_complementary_info = len(complementary_info_keys) > 0
+
         # If not, we need to infer it from episode boundaries
         if not has_done_key:
             print("'next.done' key not found in dataset. Inferring from episode boundaries...")
@@ -735,6 +810,16 @@ class ReplayBuffer:
                         next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
                     next_state = next_state_data
 
+            # ----- 5) Complementary info (if available) -----
+            complementary_info = None
+            if has_complementary_info:
+                complementary_info = {}
+                for key in complementary_info_keys:
+                    # Strip the "complementary_info." prefix to get the actual key
+                    clean_key = key[len("complementary_info.") :]
+                    val = current_sample[key]
+                    complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+
             # ----- Construct the Transition -----
             transition = Transition(
                 state=current_state,
@@ -743,6 +828,7 @@ class ReplayBuffer:
                 next_state=next_state,
                 done=done,
                 truncated=truncated,
+                complementary_info=complementary_info,
             )
             transitions.append(transition)
 
@@ -775,32 +861,33 @@ def concatenate_batch_transitions(
     left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
 ) -> BatchTransition:
     """NOTE: Be careful it change the left_batch_transitions in place"""
+    # Concatenate state fields
     left_batch_transitions["state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["state"][key],
-                right_batch_transition["state"][key],
-            ],
+            [left_batch_transitions["state"][key], right_batch_transition["state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["state"]
     }
+
+    # Concatenate basic fields
     left_batch_transitions["action"] = torch.cat(
         [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
     )
     left_batch_transitions["reward"] = torch.cat(
         [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
     )
+
+    # Concatenate next_state fields
     left_batch_transitions["next_state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["next_state"][key],
-                right_batch_transition["next_state"][key],
-            ],
+            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["next_state"]
     }
+
+    # Concatenate done and truncated fields
     left_batch_transitions["done"] = torch.cat(
         [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
     )
@@ -808,8 +895,26 @@ def concatenate_batch_transitions(
         [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
         dim=0,
     )
+
+    # Handle complementary_info
+    left_info = left_batch_transitions.get("complementary_info")
+    right_info = right_batch_transition.get("complementary_info")
+
+    # Only process if right_info exists
+    if right_info is not None:
+        # Initialize left complementary_info if needed
+        if left_info is None:
+            left_batch_transitions["complementary_info"] = right_info
+        else:
+            # Concatenate each field
+            for key in right_info:
+                if key in left_info:
+                    left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
+                else:
+                    left_info[key] = right_info[key]
+
     return left_batch_transitions
 
 
 if __name__ == "__main__":
-    pass  # All test code has been removed from this module
+    pass
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 37586fe9..5489d6dc 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+# !/usr/bin/env python
 
 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.

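Conceptually, the gripper penalty carried in complementary_info is subtracted from the reward before the grasp critic's Bellman target is formed: the next discrete action is picked with the online grasp critic and evaluated with the target one, as in the patch's own comment. A worked sketch of that target computation under those assumptions follows; grasp_critic_target and its argument names are illustrative, not the library API.

from __future__ import annotations

import torch


def grasp_critic_target(
    rewards: torch.Tensor,            # (B,)
    done: torch.Tensor,               # (B,) in {0, 1}
    next_q_online: torch.Tensor,      # (B, num_discrete_actions) from the online grasp critic
    next_q_target: torch.Tensor,      # (B, num_discrete_actions) from the target grasp critic
    gamma: float = 0.99,
    gripper_penalty: torch.Tensor | None = None,  # (B,) optional, from complementary_info
) -> torch.Tensor:
    """Double-DQN style target for the grasp critic (illustrative sketch)."""
    # Select the best next discrete action with the online network...
    best_next_action = next_q_online.argmax(dim=-1, keepdim=True)
    # ...but evaluate it with the target network.
    target_next_q = next_q_target.gather(-1, best_next_action).squeeze(-1)

    # Subtract the gripper penalty from the reward when it is available.
    if gripper_penalty is not None:
        rewards = rewards - gripper_penalty

    return rewards + (1.0 - done) * gamma * target_next_q


if __name__ == "__main__":
    B, A = 4, 3
    target = grasp_critic_target(
        rewards=torch.rand(B),
        done=torch.zeros(B),
        next_q_online=torch.rand(B, A),
        next_q_target=torch.rand(B, A),
        gripper_penalty=torch.full((B,), 0.1),
    )
    print(target.shape)  # torch.Size([4])
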
From 6c103906532c542c2ba763b5300afd2abc25dda8 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Mon, 7 Apr 2025 14:48:42 +0000
Subject: [PATCH 19/28] Refactor complementary_info handling in ReplayBuffer

---
 lerobot/scripts/server/buffer.py | 132 ++++++++++++++++++++++++++++---
 1 file changed, 120 insertions(+), 12 deletions(-)

diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 92ad7dc7..185412fd 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -74,11 +74,15 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
 
     # Move complementary_info tensors if present
     if transition.get("complementary_info") is not None:
-        transition["complementary_info"] = {
-            key: val.to(device, non_blocking=non_blocking)
-            for key, val in transition["complementary_info"].items()
-        }
-
+        for key, val in transition["complementary_info"].items():
+            if isinstance(val, torch.Tensor):
+                transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
+            elif isinstance(val, (int, float, bool)):
+                transition["complementary_info"][key] = torch.tensor(
+                    val, device=device, non_blocking=non_blocking
+                )
+            else:
+                raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
     return transition
 
 
@@ -650,6 +654,8 @@ class ReplayBuffer:
         if self.has_complementary_info:
             for key in self.complementary_info_keys:
                 sample_val = self.complementary_info[key][0]
+                if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
+                    sample_val = sample_val.unsqueeze(0)
                 f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
                 features[f"complementary_info.{key}"] = f_info
 
@@ -689,7 +695,15 @@ class ReplayBuffer:
             # Add complementary_info if available
             if self.has_complementary_info:
                 for key in self.complementary_info_keys:
-                    frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu()
+                    val = self.complementary_info[key][actual_idx]
+                    # Convert tensors to CPU
+                    if isinstance(val, torch.Tensor):
+                        if val.ndim == 0:
+                            val = val.unsqueeze(0)
+                        frame_dict[f"complementary_info.{key}"] = val.cpu()
+                    # Non-tensor values can be used directly
+                    else:
+                        frame_dict[f"complementary_info.{key}"] = val
 
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
@@ -758,7 +772,7 @@ class ReplayBuffer:
         has_done_key = "next.done" in sample
 
         # Check for complementary_info keys
-        complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")]
+        complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
         has_complementary_info = len(complementary_info_keys) > 0
 
         # If not, we need to infer it from episode boundaries
@@ -818,7 +832,13 @@ class ReplayBuffer:
                     # Strip the "complementary_info." prefix to get the actual key
                     clean_key = key[len("complementary_info.") :]
                     val = current_sample[key]
-                    complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+                    # Handle tensor and non-tensor values differently
+                    if isinstance(val, torch.Tensor):
+                        complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+                    else:
+                        # TODO: (azouitine) Check if it's necessary to convert to tensor
+                        # For non-tensor values, use directly
+                        complementary_info[clean_key] = val
 
             # ----- Construct the Transition -----
             transition = Transition(
@@ -836,12 +856,13 @@ class ReplayBuffer:
 
 
 # Utility function to guess shapes/dtypes from a tensor
-def guess_feature_info(t: torch.Tensor, name: str):
+def guess_feature_info(t, name: str):
     """
-    Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
+    Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
     If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
-    Otherwise default to 'float32' for numeric. You can customize as needed.
+    Otherwise, default to an appropriate numeric dtype.
     """
+
     shape = tuple(t.shape)
     # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
     if len(shape) == 3 and shape[0] in [1, 3]:
@@ -917,4 +938,91 @@ def concatenate_batch_transitions(
 
 
 if __name__ == "__main__":
-    pass
+
+    def test_load_dataset_with_complementary_info():
+        """
+        Test loading a dataset with complementary_info into a ReplayBuffer.
+        The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains
+        gripper_penalty values in complementary_info.
+        """
+        import time
+        from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+        print("Loading dataset with complementary info...")
+        # Load a small subset of the dataset (first episode)
+        dataset = LeRobotDataset(
+            repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
+        )
+
+        print(f"Dataset loaded with {len(dataset)} frames")
+        print(f"Dataset features: {list(dataset.features.keys())}")
+
+        # Check if dataset has complementary_info.gripper_penalty
+        sample = dataset[0]
+        complementary_info_keys = [key for key in sample if key.startswith("complementary_info")]
+        print(f"Complementary info keys: {complementary_info_keys}")
+
+        if "complementary_info.gripper_penalty" in sample:
+            print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}")
+
+        # Extract state keys for the buffer
+        state_keys = []
+        for key in sample:
+            if key.startswith("observation"):
+                state_keys.append(key)
+
+        print(f"Using state keys: {state_keys}")
+
+        # Create a replay buffer from the dataset
+        start_time = time.time()
+        buffer = ReplayBuffer.from_lerobot_dataset(
+            lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False
+        )
+        load_time = time.time() - start_time
+        print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
+        print(f"Buffer size: {len(buffer)}")
+
+        # Check if complementary_info was transferred correctly
+        print("Sampling from buffer to check complementary_info...")
+        batch = buffer.sample(batch_size=4)
+
+        if batch["complementary_info"] is not None:
+            print("Complementary info in batch:")
+            for key, value in batch["complementary_info"].items():
+                print(f"  {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}")
+                if key == "gripper_penalty":
+                    print(f"  Sample gripper_penalty values: {value[:5]}")
+        else:
+            print("No complementary_info found in batch")
+
+        # Now convert the buffer back to a LeRobotDataset
+        print("\nConverting buffer back to LeRobotDataset...")
+        start_time = time.time()
+        new_dataset = buffer.to_lerobot_dataset(
+            repo_id="test_dataset_from_buffer",
+            fps=dataset.fps,
+            root="./test_dataset_from_buffer",
+            task_name="test_conversion",
+        )
+        convert_time = time.time() - start_time
+        print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
+        print(f"New dataset size: {len(new_dataset)} frames")
+
+        # Check if complementary_info was preserved
+        new_sample = new_dataset[0]
+        new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")]
+        print(f"New dataset complementary info keys: {new_complementary_info_keys}")
+
+        if "complementary_info.gripper_penalty" in new_sample:
+            print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}")
+
+        # Compare original and new datasets
+        print("\nComparing original and new datasets:")
+        print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
+        print(f"Original features: {list(dataset.features.keys())}")
+        print(f"New features: {list(new_dataset.features.keys())}")
+
+        return buffer, dataset, new_dataset
+
+    # Run the test
+    test_load_dataset_with_complementary_info()

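Patch 19 above makes complementary_info tolerant of both tensors and plain Python scalars when transitions are moved between devices. A small self-contained sketch of that branching (hypothetical helper name; the real entry point is move_transition_to_device). Note that torch.tensor() takes no non_blocking argument, so scalars are wrapped with device= only, as the later cleanup in this series also does:

import torch


def move_info_to_device(info: dict, device: str = "cpu") -> dict:
    """Move tensor values to `device`, wrap plain scalars, reject anything else."""
    non_blocking = device == "cuda"
    for key, val in info.items():
        if isinstance(val, torch.Tensor):
            info[key] = val.to(device, non_blocking=non_blocking)
        elif isinstance(val, (int, float, bool)):
            # Scalars are wrapped with device= only; torch.tensor() has no non_blocking kwarg.
            info[key] = torch.tensor(val, device=device)
        else:
            raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
    return info


print(move_info_to_device({"gripper_penalty": -0.1, "grasp_count": torch.tensor(3)}))
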
From 632b2b46c1f7dca5338b87b1d20c3306c1719ff6 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Mon, 7 Apr 2025 15:44:06 +0000
Subject: [PATCH 20/28] fix sign issue

---
 lerobot/common/policies/sac/modeling_sac.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index e3d3765e..9b909813 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -448,7 +448,7 @@ class SACPolicy(
         # Compute target Q-value with Bellman equation
         rewards_gripper = rewards
         if gripper_penalties is not None:
-            rewards_gripper = rewards - gripper_penalties
+            rewards_gripper = rewards + gripper_penalties
         target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations

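The one-line change above flips how the grasp critic's Bellman target incorporates the gripper penalty: assuming gripper_penalty is logged as a non-positive value (e.g. -0.1 for an unnecessary open/close), it must be added to the reward rather than subtracted. A tiny numeric sketch with made-up values:

import torch

rewards = torch.tensor([0.0, 1.0, 0.0])
gripper_penalties = torch.tensor([-0.1, 0.0, -0.1])  # already non-positive
done = torch.tensor([0.0, 1.0, 0.0])
discount = 0.99
target_next_grasp_q = torch.tensor([0.5, 0.7, 0.2])

rewards_gripper = rewards + gripper_penalties
target_grasp_q = rewards_gripper + (1 - done) * discount * target_next_grasp_q
print(target_grasp_q)  # ≈ tensor([0.3950, 1.0000, 0.0980])
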
From a7be613ee8329d3f3a2784d314ad599605b72c4a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 7 Apr 2025 15:48:39 +0000
Subject: [PATCH 21/28] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 lerobot/scripts/server/buffer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 185412fd..8db1a82c 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -946,6 +946,7 @@ if __name__ == "__main__":
         gripper_penalty values in complementary_info.
         """
         import time
+
         from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
         print("Loading dataset with complementary info...")

From a8135629b4a22141fcae2ea86c050b49e4374d19 Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Tue, 8 Apr 2025 08:50:02 +0000
Subject: [PATCH 22/28] Add rounding for safety

---
 lerobot/common/policies/sac/modeling_sac.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 9b909813..e3d83d36 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -421,6 +421,7 @@ class SACPolicy(
         # In the buffer we have the full action space (continuous + discrete)
         # We need to split them before concatenating them in the critic forward
         actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
+        actions_discrete = torch.round(actions_discrete)
         actions_discrete = actions_discrete.long()
 
         if complementary_info is not None:

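The extra torch.round above guards the cast to integer indices: .long() truncates toward zero, so a continuous gripper action that comes back as 0.999 would be mapped to discrete action 0 instead of 1. Illustrative values:

import torch

actions_discrete = torch.tensor([[0.999], [1.2], [2.0 - 1e-3]])
print(actions_discrete.long().squeeze(-1))               # tensor([0, 1, 1])
print(torch.round(actions_discrete).long().squeeze(-1))  # tensor([1, 1, 2])
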
From d948b95d22b6c9ae65a691c8bd16255a3929b03b Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Wed, 9 Apr 2025 13:20:51 +0000
Subject: [PATCH 23/28] fix caching and make dataset stats optional

---
 .../common/policies/sac/configuration_sac.py  |  2 +-
 lerobot/common/policies/sac/modeling_sac.py   | 32 +++++++++++--------
 lerobot/scripts/server/learner_server.py      |  4 +--
 3 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 3d01f47c..684ac17f 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -120,7 +120,7 @@ class SACConfig(PreTrainedConfig):
         }
     )
 
-    dataset_stats: dict[str, dict[str, list[float]]] = field(
+    dataset_stats: dict[str, dict[str, list[float]]] | None = field(
         default_factory=lambda: {
             "observation.image": {
                 "mean": [0.485, 0.456, 0.406],
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index e3d83d36..34707dc4 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -63,16 +63,21 @@ class SACPolicy(
         else:
             self.normalize_inputs = nn.Identity()
 
-        output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
 
-        # HACK: This is hacky and should be removed
-        dataset_stats = dataset_stats or output_normalization_params
-        self.normalize_targets = Normalize(
-            config.output_features, config.normalization_mapping, dataset_stats
-        )
-        self.unnormalize_outputs = Unnormalize(
-            config.output_features, config.normalization_mapping, dataset_stats
-        )
+        if config.dataset_stats is not None:
+            output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
+
+            # HACK: This is hacky and should be removed
+            dataset_stats = dataset_stats or output_normalization_params
+            self.normalize_targets = Normalize(
+                config.output_features, config.normalization_mapping, dataset_stats
+            )
+            self.unnormalize_outputs = Unnormalize(
+                config.output_features, config.normalization_mapping, dataset_stats
+            )
+        else:
+            self.normalize_targets = nn.Identity()
+            self.unnormalize_outputs = nn.Identity()
 
         # NOTE: For images the encoder should be shared between the actor and critic
         if config.shared_encoder:
@@ -188,7 +193,7 @@ class SACPolicy(
         # We cached the encoder output to avoid recomputing it
         observations_features = None
         if self.shared_encoder:
-            observations_features = self.actor.encoder.get_image_features(batch)
+            observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
 
         actions, _, _ = self.actor(batch, observations_features)
         actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -564,8 +569,7 @@ class SACObservationEncoder(nn.Module):
         feat = []
         obs_dict = self.input_normalization(obs_dict)
         if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
-            vision_encoder_cache = self.get_image_features(obs_dict)
-            feat.append(vision_encoder_cache)
+            vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
 
         if vision_encoder_cache is not None:
             feat.append(vision_encoder_cache)
@@ -580,8 +584,10 @@ class SACObservationEncoder(nn.Module):
 
         return features
 
-    def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor:
+    def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
         # [N*B, C, H, W]
+        if normalize:
+            batch = self.input_normalization(batch)
         if len(self.all_image_keys) > 0:
             # Batch all images along the batch dimension, then encode them.
             images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 5489d6dc..707547a1 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -1026,8 +1026,8 @@ def get_observation_features(
         return None, None
 
     with torch.no_grad():
-        observation_features = policy.actor.encoder.get_image_features(observations)
-        next_observation_features = policy.actor.encoder.get_image_features(next_observations)
+        observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
+        next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
 
     return observation_features, next_observation_features
 

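Patch 23 above makes normalization explicit in the image-feature cache: when features are precomputed outside the encoder (the shared-encoder paths shown above), get_image_features must normalize the raw batch itself; when it is called from the encoder's forward, the batch has already been normalized, so normalize=False avoids doing it twice. A toy stand-in for the encoder, not the real SACObservationEncoder:

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for input_normalization: scale images to [0, 1].
        self.normalize = lambda batch: {k: v / 255.0 for k, v in batch.items()}
        self.backbone = nn.Flatten()

    def get_image_features(self, batch, normalize: bool = True):
        if normalize:
            batch = self.normalize(batch)
        return self.backbone(batch["observation.image"])

    def forward(self, batch, vision_cache=None):
        batch = self.normalize(batch)
        if vision_cache is None:
            # Batch is already normalized here, so skip the second normalization.
            vision_cache = self.get_image_features(batch, normalize=False)
        return vision_cache


enc = TinyEncoder()
obs = {"observation.image": torch.full((1, 3, 2, 2), 255.0)}
cached = enc.get_image_features(obs, normalize=True)  # caller-side caching path
print(torch.allclose(enc(obs), cached))                # True
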
From e7edf2a8d8d6ca58fbfc7c93c8ae95446e5f5866 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 9 Apr 2025 13:51:31 +0000
Subject: [PATCH 24/28] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 lerobot/common/policies/sac/modeling_sac.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 34707dc4..b51f9b8f 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -63,7 +63,6 @@ class SACPolicy(
         else:
             self.normalize_inputs = nn.Identity()
 
-
         if config.dataset_stats is not None:
             output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
 

From 5428ab96f5400ef298ac89daddb72e524f556fae Mon Sep 17 00:00:00 2001
From: Michel Aractingi <michel.aractingi@gmail.com>
Date: Wed, 9 Apr 2025 17:04:43 +0200
Subject: [PATCH 25/28] General fixes in code: removed delta action, fixed
 grasp penalty, added logic to put gripper penalty in info

---
 lerobot/common/envs/configs.py                |   8 +-
 lerobot/common/policies/sac/modeling_sac.py   |   1 +
 lerobot/common/robot_devices/control_utils.py |   3 -
 lerobot/scripts/control_robot.py              |   2 -
 lerobot/scripts/server/buffer.py              |  12 +-
 lerobot/scripts/server/gym_manipulator.py     | 110 +++++++++++-------
 lerobot/scripts/server/learner_server.py      |   4 +-
 7 files changed, 75 insertions(+), 65 deletions(-)

diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index a6eda93b..02911332 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -171,7 +171,6 @@ class VideoRecordConfig:
 class WrapperConfig:
     """Configuration for environment wrappers."""
 
-    delta_action: float | None = None
     joint_masking_action_space: list[bool] | None = None
 
 
@@ -191,7 +190,6 @@ class EnvWrapperConfig:
     """Configuration for environment wrappers."""
 
     display_cameras: bool = False
-    delta_action: float = 0.1
     use_relative_joint_positions: bool = True
     add_joint_velocity_to_observation: bool = False
     add_ee_pose_to_observation: bool = False
@@ -203,11 +201,13 @@ class EnvWrapperConfig:
     joint_masking_action_space: Optional[Any] = None
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
-    gripper_quantization_threshold: float = 0.8
-    gripper_penalty: float = 0.0
+    gripper_quantization_threshold: float | None = 0.8
+    gripper_penalty: float = 0.0 
+    gripper_penalty_in_reward: bool = False
     open_gripper_on_reset: bool = False
 
 
+
 @EnvConfig.register_subclass(name="gym_manipulator")
 @dataclass
 class HILSerlRobotEnvConfig(EnvConfig):
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index b51f9b8f..b8827a1b 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -428,6 +428,7 @@ class SACPolicy(
         actions_discrete = torch.round(actions_discrete)
         actions_discrete = actions_discrete.long()
 
+        gripper_penalties: Tensor | None = None
         if complementary_info is not None:
             gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
 
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index c834e9e9..170a35de 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -221,7 +221,6 @@ def record_episode(
         events=events,
         policy=policy,
         fps=fps,
-        # record_delta_actions=record_delta_actions,
         teleoperate=policy is None,
         single_task=single_task,
     )
@@ -267,8 +266,6 @@ def control_loop(
 
         if teleoperate:
             observation, action = robot.teleop_step(record_data=True)
-            # if record_delta_actions:
-            #     action["action"] = action["action"] - current_joint_positions
         else:
             observation = robot.capture_observation()
 
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 1013001a..658371a1 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -363,8 +363,6 @@ def replay(
         start_episode_t = time.perf_counter()
 
         action = actions[idx]["action"]
-        # if replay_delta_actions:
-        #     action = action + current_joint_positions
         robot.send_action(action)
 
         dt_s = time.perf_counter() - start_episode_t
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 8db1a82c..92e03d33 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -78,9 +78,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
             if isinstance(val, torch.Tensor):
                 transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
             elif isinstance(val, (int, float, bool)):
-                transition["complementary_info"][key] = torch.tensor(
-                    val, device=device, non_blocking=non_blocking
-                )
+                transition["complementary_info"][key] = torch.tensor(val, device=device)
             else:
                 raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
     return transition
@@ -505,7 +503,6 @@ class ReplayBuffer:
         state_keys: Optional[Sequence[str]] = None,
         capacity: Optional[int] = None,
         action_mask: Optional[Sequence[int]] = None,
-        action_delta: Optional[float] = None,
         image_augmentation_function: Optional[Callable] = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
@@ -520,7 +517,6 @@ class ReplayBuffer:
             state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
             capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
             action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
-            action_delta (Optional[float]): Factor to divide actions by.
             image_augmentation_function (Optional[Callable]): Function for image augmentation.
                 If None, uses default random shift with pad=4.
             use_drq (bool): Whether to use DrQ image augmentation when sampling.
@@ -565,9 +561,6 @@ class ReplayBuffer:
                 else:
                     first_action = first_action[:, action_mask]
 
-            if action_delta is not None:
-                first_action = first_action / action_delta
-
             # Get complementary info if available
             first_complementary_info = None
             if (
@@ -598,9 +591,6 @@ class ReplayBuffer:
                 else:
                     action = action[:, action_mask]
 
-            if action_delta is not None:
-                action = action / action_delta
-
             replay_buffer.add(
                 state=data["state"],
                 action=action,
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 3aa75466..44bbcf9b 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env):
         self,
         robot,
         use_delta_action_space: bool = True,
-        delta: float | None = None,
         display_cameras: bool = False,
     ):
         """
@@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env):
             robot: The robot interface object used to connect and interact with the physical robot.
             use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
                 joint positions are used.
-            delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
-                0 and 1 when using a delta action space.
             display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
         """
         super().__init__()
@@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env):
         self.current_step = 0
         self.episode_data = None
 
-        self.delta = delta
         self.use_delta_action_space = use_delta_action_space
         self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
 
@@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
         self.device = device
 
     def step(self, action):
-        observation, _, terminated, truncated, info = self.env.step(action)
+        observation, reward, terminated, truncated, info = self.env.step(action)
         images = [
             observation[key].to(self.device, non_blocking=self.device.type == "cuda")
             for key in observation
@@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
         ]
         start_time = time.perf_counter()
         with torch.inference_mode():
-            reward = (
+            success = (
                 self.reward_classifier.predict_reward(images, threshold=0.8)
                 if self.reward_classifier is not None
                 else 0.0
             )
         info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
 
-        if reward == 1.0:
+        if success == 1.0:
             terminated = True
+            reward = 1.0
+
         return observation, reward, terminated, truncated, info
 
     def reset(self, seed=None, options=None):
@@ -720,19 +718,31 @@ class ResetWrapper(gym.Wrapper):
         env: HILSerlRobotEnv,
         reset_pose: np.ndarray | None = None,
         reset_time_s: float = 5,
+        open_gripper_on_reset: bool = False
     ):
         super().__init__(env)
         self.reset_time_s = reset_time_s
         self.reset_pose = reset_pose
         self.robot = self.unwrapped.robot
+        self.open_gripper_on_reset = open_gripper_on_reset
 
     def reset(self, *, seed=None, options=None):
+
+
         if self.reset_pose is not None:
             start_time = time.perf_counter()
             log_say("Reset the environment.", play_sounds=True)
             reset_follower_position(self.robot, self.reset_pose)
             busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
             log_say("Reset the environment done.", play_sounds=True)
+            if self.open_gripper_on_reset:
+                current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
+                current_joint_pos[-1] = MAX_GRIPPER_COMMAND
+                self.robot.send_action(torch.from_numpy(current_joint_pos))
+                busy_wait(0.1)
+                current_joint_pos[-1] = 0.0
+                self.robot.send_action(torch.from_numpy(current_joint_pos))
+                busy_wait(0.2)
         else:
             log_say(
                 f"Manually reset the environment for {self.reset_time_s} seconds.",
@@ -762,37 +772,48 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
 
 
 class GripperPenaltyWrapper(gym.RewardWrapper):
-    def __init__(self, env, penalty: float = -0.1):
+    def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
         super().__init__(env)
         self.penalty = penalty
+        self.gripper_penalty_in_reward = gripper_penalty_in_reward
         self.last_gripper_state = None
+        
 
     def reward(self, reward, action):
         gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
 
-        if isinstance(action, tuple):
-            action = action[0]
-        action_normalized = action[-1] / MAX_GRIPPER_COMMAND
+        action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND
 
-        gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
-            gripper_state_normalized > 0.9 and action_normalized < 0.1
+        gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
+            gripper_state_normalized > 0.75 and action_normalized < -0.5
         )
-        breakpoint()
 
-        return reward + self.penalty * gripper_penalty_bool
+        return reward + self.penalty * int(gripper_penalty_bool)
 
     def step(self, action):
         self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+        if isinstance(action, tuple):
+            gripper_action = action[0][-1]
+        else:
+            gripper_action = action[-1]
         obs, reward, terminated, truncated, info = self.env.step(action)
-        reward = self.reward(reward, action)
+        gripper_penalty = self.reward(reward, gripper_action)
+
+        if self.gripper_penalty_in_reward:
+            reward += gripper_penalty
+        else:
+            info["gripper_penalty"] = gripper_penalty
+            
         return obs, reward, terminated, truncated, info
 
     def reset(self, **kwargs):
         self.last_gripper_state = None
-        return super().reset(**kwargs)
+        obs, info = super().reset(**kwargs)
+        if self.gripper_penalty_in_reward:
+            info["gripper_penalty"] = 0.0
+        return obs, info
 
-
-class GripperQuantizationWrapper(gym.ActionWrapper):
+class GripperActionWrapper(gym.ActionWrapper):
     def __init__(self, env, quantization_threshold: float = 0.2):
         super().__init__(env)
         self.quantization_threshold = quantization_threshold
@@ -801,16 +822,18 @@ class GripperQuantizationWrapper(gym.ActionWrapper):
         is_intervention = False
         if isinstance(action, tuple):
             action, is_intervention = action
-
         gripper_command = action[-1]
-        # Quantize gripper command to -1, 0 or 1
-        if gripper_command < -self.quantization_threshold:
-            gripper_command = -MAX_GRIPPER_COMMAND
-        elif gripper_command > self.quantization_threshold:
-            gripper_command = MAX_GRIPPER_COMMAND
-        else:
-            gripper_command = 0.0
 
+        # Gripper actions are between 0, 2
+        # we want to quantize them to -1, 0 or 1
+        gripper_command = gripper_command - 1.0
+
+        if self.quantization_threshold is not None:
+            # Quantize gripper command to -1, 0 or 1
+            gripper_command = (
+                np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
+            )
+        gripper_command = gripper_command * MAX_GRIPPER_COMMAND
         gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
         gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
         action[-1] = gripper_action.item()
@@ -836,10 +859,12 @@ class EEActionWrapper(gym.ActionWrapper):
             ]
         )
         if self.use_gripper:
-            action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
+            # gripper actions open at 2.0, and closed at 0.0
+            min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
+            max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
         ee_action_space = gym.spaces.Box(
-            low=-action_space_bounds,
-            high=action_space_bounds,
+            low=min_action_space_bounds,
+            high=max_action_space_bounds,
             shape=(3 + int(self.use_gripper),),
             dtype=np.float32,
         )
@@ -997,11 +1022,11 @@ class GamepadControlWrapper(gym.Wrapper):
         if self.use_gripper:
             gripper_command = self.controller.gripper_command()
             if gripper_command == "open":
-                gamepad_action = np.concatenate([gamepad_action, [1.0]])
+                gamepad_action = np.concatenate([gamepad_action, [2.0]])
             elif gripper_command == "close":
-                gamepad_action = np.concatenate([gamepad_action, [-1.0]])
-            else:
                 gamepad_action = np.concatenate([gamepad_action, [0.0]])
+            else:
+                gamepad_action = np.concatenate([gamepad_action, [1.0]])
 
         # Check episode ending buttons
         # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
@@ -1141,7 +1166,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     env = HILSerlRobotEnv(
         robot=robot,
         display_cameras=cfg.wrapper.display_cameras,
-        delta=cfg.wrapper.delta_action,
         use_delta_action_space=cfg.wrapper.use_relative_joint_positions
         and cfg.wrapper.ee_action_space_params is None,
     )
@@ -1165,10 +1189,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.use_gripper:
-        env = GripperQuantizationWrapper(
+        env = GripperActionWrapper(
             env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
         )
-        # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
+        if cfg.wrapper.gripper_penalty is not None:
+            env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
 
     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
@@ -1176,6 +1201,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
             ee_action_space_params=cfg.wrapper.ee_action_space_params,
             use_gripper=cfg.wrapper.use_gripper,
         )
+
     if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
         # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
@@ -1192,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env=env,
         reset_pose=cfg.wrapper.fixed_reset_joint_positions,
         reset_time_s=cfg.wrapper.reset_time_s,
+        open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset
     )
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
@@ -1341,11 +1368,10 @@ def record_dataset(env, policy, cfg):
         dataset.push_to_hub()
 
 
-def replay_episode(env, repo_id, root=None, episode=0):
+def replay_episode(env, cfg):
     from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
-    local_files_only = root is not None
-    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
+    dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
     env.reset()
 
     actions = dataset.hf_dataset.select_columns("action")
@@ -1353,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
     for idx in range(dataset.num_frames):
         start_episode_t = time.perf_counter()
 
-        action = actions[idx]["action"][:4]
+        action = actions[idx]["action"]
         env.step((action, False))
         # env.step((action / env.unwrapped.delta, False))
 
@@ -1384,9 +1410,7 @@ def main(cfg: EnvConfig):
     if cfg.mode == "replay":
         replay_episode(
             env,
-            cfg.replay_repo_id,
-            root=cfg.dataset_root,
-            episode=cfg.replay_episode,
+            cfg=cfg,
         )
         exit()
 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 707547a1..e4bcc620 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -406,7 +406,8 @@ def add_actor_information_and_train(
                 "next_state": next_observations,
                 "done": done,
                 "observation_feature": observation_features,
-                "next_observation_feature": next_observation_features,
+                "next_observation_feature": next_observation_features, 
+                "complementary_info": batch["complementary_info"],
             }
 
             # Use the forward method for critic loss (includes both main critic and grasp critic)
@@ -992,7 +993,6 @@ def initialize_offline_replay_buffer(
         device=device,
         state_keys=cfg.policy.input_features.keys(),
         action_mask=active_action_dims,
-        action_delta=cfg.env.wrapper.delta_action,
         storage_device=storage_device,
         optimize_memory=True,
         capacity=cfg.policy.offline_buffer_capacity,

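Among the changes above, GripperActionWrapper now expects gripper actions in [0, 2] (0 = close, 1 = stay, 2 = open), recentres them around zero, optionally quantizes them, and converts the result into a joint-position delta. A compact sketch of that mapping; MAX_GRIPPER_COMMAND here is a placeholder for the fully-open joint command:

import numpy as np

MAX_GRIPPER_COMMAND = 30.0  # placeholder value for illustration


def quantize_gripper(gripper_action: float, threshold: float | None = 0.8) -> float:
    command = gripper_action - 1.0                 # [0, 2] -> [-1, 1]
    if threshold is not None:
        # Quantize to -1, 0 or 1 when the command clears the dead-zone threshold.
        command = np.sign(command) if abs(command) > threshold else 0.0
    return command * MAX_GRIPPER_COMMAND           # delta applied to the gripper joint


for a in (0.0, 0.95, 1.0, 1.9, 2.0):
    print(a, "->", quantize_gripper(a))
# 0.0 -> -30.0, 0.95 -> 0.0, 1.0 -> 0.0, 1.9 -> 30.0, 2.0 -> 30.0
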
From ba09f44eb7f312c7f12287a4d56ae8d722633fbf Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 9 Apr 2025 15:05:17 +0000
Subject: [PATCH 26/28] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 lerobot/common/envs/configs.py            |  3 +--
 lerobot/scripts/server/gym_manipulator.py | 22 +++++++++++-----------
 lerobot/scripts/server/learner_server.py  |  2 +-
 3 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 02911332..77220c3c 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -202,12 +202,11 @@ class EnvWrapperConfig:
     ee_action_space_params: Optional[EEActionSpaceConfig] = None
     use_gripper: bool = False
     gripper_quantization_threshold: float | None = 0.8
-    gripper_penalty: float = 0.0 
+    gripper_penalty: float = 0.0
     gripper_penalty_in_reward: bool = False
     open_gripper_on_reset: bool = False
 
 
-
 @EnvConfig.register_subclass(name="gym_manipulator")
 @dataclass
 class HILSerlRobotEnvConfig(EnvConfig):
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 44bbcf9b..6a8e848a 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -718,7 +718,7 @@ class ResetWrapper(gym.Wrapper):
         env: HILSerlRobotEnv,
         reset_pose: np.ndarray | None = None,
         reset_time_s: float = 5,
-        open_gripper_on_reset: bool = False
+        open_gripper_on_reset: bool = False,
     ):
         super().__init__(env)
         self.reset_time_s = reset_time_s
@@ -727,8 +727,6 @@ class ResetWrapper(gym.Wrapper):
         self.open_gripper_on_reset = open_gripper_on_reset
 
     def reset(self, *, seed=None, options=None):
-
-
         if self.reset_pose is not None:
             start_time = time.perf_counter()
             log_say("Reset the environment.", play_sounds=True)
@@ -777,12 +775,11 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         self.penalty = penalty
         self.gripper_penalty_in_reward = gripper_penalty_in_reward
         self.last_gripper_state = None
-        
 
     def reward(self, reward, action):
         gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
 
-        action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND
+        action_normalized = action - 1.0  # action / MAX_GRIPPER_COMMAND
 
         gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
             gripper_state_normalized > 0.75 and action_normalized < -0.5
@@ -803,7 +800,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
             reward += gripper_penalty
         else:
             info["gripper_penalty"] = gripper_penalty
-            
+
         return obs, reward, terminated, truncated, info
 
     def reset(self, **kwargs):
@@ -813,6 +810,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
             info["gripper_penalty"] = 0.0
         return obs, info
 
+
 class GripperActionWrapper(gym.ActionWrapper):
     def __init__(self, env, quantization_threshold: float = 0.2):
         super().__init__(env)
@@ -1189,11 +1187,13 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.use_gripper:
-        env = GripperActionWrapper(
-            env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
-        )
+        env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
         if cfg.wrapper.gripper_penalty is not None:
-            env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
+            env = GripperPenaltyWrapper(
+                env=env,
+                penalty=cfg.wrapper.gripper_penalty,
+                gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
+            )
 
     if cfg.wrapper.ee_action_space_params is not None:
         env = EEActionWrapper(
@@ -1218,7 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         env=env,
         reset_pose=cfg.wrapper.fixed_reset_joint_positions,
         reset_time_s=cfg.wrapper.reset_time_s,
-        open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset
+        open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
     )
     if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
         env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index e4bcc620..5b39d0d3 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -406,7 +406,7 @@ def add_actor_information_and_train(
                 "next_state": next_observations,
                 "done": done,
                 "observation_feature": observation_features,
-                "next_observation_feature": next_observation_features, 
+                "next_observation_feature": next_observation_features,
                 "complementary_info": batch["complementary_info"],
             }
 

From 854bfb4ff8548c2276e70de0852acd3d620b948e Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Fri, 11 Apr 2025 11:50:46 +0000
Subject: [PATCH 27/28] fix encoder training

---
 lerobot/common/policies/sac/modeling_sac.py | 32 +++++++++++++--------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index b8827a1b..9ffdf154 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -525,9 +525,10 @@ class SACObservationEncoder(nn.Module):
             self.aggregation_size += config.latent_dim * self.camera_number
 
             if config.freeze_vision_encoder:
-                freeze_image_encoder(self.image_enc_layers)
-            else:
-                self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+                freeze_image_encoder(self.image_enc_layers.image_enc_layers)
+
+            self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
+
             self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
 
         if "observation.state" in config.input_features:
@@ -958,23 +959,25 @@ class DefaultImageEncoder(nn.Module):
         dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
         with torch.inference_mode():
             self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
-        self.image_enc_layers.extend(
-            nn.Sequential(
-                nn.Flatten(),
-                nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
-                nn.LayerNorm(config.latent_dim),
-                nn.Tanh(),
-            )
+        self.image_enc_proj = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
+            nn.LayerNorm(config.latent_dim),
+            nn.Tanh(),
         )
 
+        self.parameters_to_optimize = []
+        if not config.freeze_vision_encoder:
+            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+
     def forward(self, x):
-        return self.image_enc_layers(x)
+        return self.image_enc_proj(self.image_enc_layers(x))
 
 
 class PretrainedImageEncoder(nn.Module):
     def __init__(self, config: SACConfig):
         super().__init__()
-
         self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
         self.image_enc_proj = nn.Sequential(
             nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
@@ -982,6 +985,11 @@ class PretrainedImageEncoder(nn.Module):
             nn.Tanh(),
         )
 
+        self.parameters_to_optimize = []
+        if not config.freeze_vision_encoder:
+            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
+        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+
     def _load_pretrained_vision_encoder(self, config: SACConfig):
         """Set up CNN encoder"""
         from transformers import AutoModel

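Patch 27 above moves the projection head out of image_enc_layers and rebuilds parameters_to_optimize so a frozen vision backbone is excluded while the projection stays trainable. A toy module illustrating the bookkeeping (freeze_image_encoder is assumed to flip requires_grad on the backbone):

import torch.nn as nn


class ImageEncoder(nn.Module):
    """Toy stand-in for DefaultImageEncoder: optional frozen backbone + trainable projection."""

    def __init__(self, freeze_backbone: bool, latent_dim: int = 32):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        self.proj = nn.Sequential(nn.Linear(8, latent_dim), nn.LayerNorm(latent_dim), nn.Tanh())
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad_(False)
        # Only expose what the optimizer is allowed to update.
        self.parameters_to_optimize = []
        if not freeze_backbone:
            self.parameters_to_optimize += list(self.backbone.parameters())
        self.parameters_to_optimize += list(self.proj.parameters())

    def forward(self, x):
        return self.proj(self.backbone(x))


frozen = ImageEncoder(freeze_backbone=True)
print(len(frozen.parameters_to_optimize))  # 4: Linear weight/bias + LayerNorm weight/bias
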
From 320a1a92a39e3cf91c5cd9b018d78119da326c0c Mon Sep 17 00:00:00 2001
From: AdilZouitine <adilzouitinegm@gmail.com>
Date: Mon, 14 Apr 2025 14:00:57 +0000
Subject: [PATCH 28/28] Refactor modeling_sac and parameter handling for
 clarity and reusability.

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
---
 lerobot/common/policies/sac/modeling_sac.py | 62 ++++++++-------------
 lerobot/scripts/server/learner_server.py    | 50 +++++++++++++++--
 2 files changed, 67 insertions(+), 45 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 9ffdf154..05937240 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -167,8 +167,12 @@ class SACPolicy(
 
     def get_optim_params(self) -> dict:
         optim_params = {
-            "actor": self.actor.parameters_to_optimize,
-            "critic": self.critic_ensemble.parameters_to_optimize,
+            "actor": [
+                p
+                for n, p in self.actor.named_parameters()
+                if not n.startswith("encoder") or not self.shared_encoder
+            ],
+            "critic": self.critic_ensemble.parameters(),
             "temperature": self.log_alpha,
         }
         if self.config.num_discrete_actions is not None:
@@ -451,11 +455,11 @@ class SACPolicy(
                 target_next_grasp_qs, dim=1, index=best_next_grasp_action
             ).squeeze(-1)
 
-        # Compute target Q-value with Bellman equation
-        rewards_gripper = rewards
-        if gripper_penalties is not None:
-            rewards_gripper = rewards + gripper_penalties
-        target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
+            # Compute target Q-value with Bellman equation
+            rewards_gripper = rewards
+            if gripper_penalties is not None:
+                rewards_gripper = rewards + gripper_penalties
+            target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(
@@ -510,7 +514,6 @@ class SACObservationEncoder(nn.Module):
         self.config = config
         self.input_normalization = input_normalizer
         self.has_pretrained_vision_encoder = False
-        self.parameters_to_optimize = []
 
         self.aggregation_size: int = 0
         if any("observation.image" in key for key in config.input_features):
@@ -527,8 +530,6 @@ class SACObservationEncoder(nn.Module):
             if config.freeze_vision_encoder:
                 freeze_image_encoder(self.image_enc_layers.image_enc_layers)
 
-            self.parameters_to_optimize += self.image_enc_layers.parameters_to_optimize
-
             self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
 
         if "observation.state" in config.input_features:
@@ -542,8 +543,6 @@ class SACObservationEncoder(nn.Module):
             )
             self.aggregation_size += config.latent_dim
 
-            self.parameters_to_optimize += list(self.state_enc_layers.parameters())
-
         if "observation.environment_state" in config.input_features:
             self.env_state_enc_layers = nn.Sequential(
                 nn.Linear(
@@ -554,10 +553,8 @@ class SACObservationEncoder(nn.Module):
                 nn.Tanh(),
             )
             self.aggregation_size += config.latent_dim
-            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
 
         self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
-        self.parameters_to_optimize += list(self.aggregation_layer.parameters())
 
     def forward(
         self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
@@ -737,12 +734,6 @@ class CriticEnsemble(nn.Module):
         self.output_normalization = output_normalization
         self.critics = nn.ModuleList(ensemble)
 
-        self.parameters_to_optimize = []
-        # Handle the case where a part of the encoder if frozen
-        if self.encoder is not None:
-            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
-        self.parameters_to_optimize += list(self.critics.parameters())
-
     def forward(
         self,
         observations: dict[str, torch.Tensor],
@@ -805,10 +796,6 @@ class GraspCritic(nn.Module):
         else:
             orthogonal_init()(self.output_layer.weight)
 
-        self.parameters_to_optimize = []
-        self.parameters_to_optimize += list(self.net.parameters())
-        self.parameters_to_optimize += list(self.output_layer.parameters())
-
     def forward(
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
@@ -840,12 +827,8 @@ class Policy(nn.Module):
         self.log_std_max = log_std_max
         self.fixed_std = fixed_std
         self.use_tanh_squash = use_tanh_squash
-        self.parameters_to_optimize = []
+        self.encoder_is_shared = encoder_is_shared
 
-        self.parameters_to_optimize += list(self.network.parameters())
-
-        if self.encoder is not None and not encoder_is_shared:
-            self.parameters_to_optimize += list(self.encoder.parameters())
         # Find the last Linear layer's output dimension
         for layer in reversed(network.net):
             if isinstance(layer, nn.Linear):
@@ -859,7 +842,6 @@ class Policy(nn.Module):
         else:
             orthogonal_init()(self.mean_layer.weight)
 
-        self.parameters_to_optimize += list(self.mean_layer.parameters())
         # Standard deviation layer or parameter
         if fixed_std is None:
             self.std_layer = nn.Linear(out_features, action_dim)
@@ -868,7 +850,6 @@ class Policy(nn.Module):
                 nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
             else:
                 orthogonal_init()(self.std_layer.weight)
-            self.parameters_to_optimize += list(self.std_layer.parameters())
 
     def forward(
         self,
@@ -877,6 +858,8 @@ class Policy(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Encode observations if encoder exists
         obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
+        if self.encoder_is_shared:
+            obs_enc = obs_enc.detach()
 
         # Get network outputs
         outputs = self.network(obs_enc)
@@ -966,13 +949,13 @@ class DefaultImageEncoder(nn.Module):
             nn.Tanh(),
         )
 
-        self.parameters_to_optimize = []
-        if not config.freeze_vision_encoder:
-            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+        self.freeze_image_encoder = config.freeze_vision_encoder
 
     def forward(self, x):
-        return self.image_enc_proj(self.image_enc_layers(x))
+        x = self.image_enc_layers(x)
+        if self.freeze_image_encoder:
+            x = x.detach()
+        return self.image_enc_proj(x)
 
 
 class PretrainedImageEncoder(nn.Module):
@@ -985,10 +968,7 @@ class PretrainedImageEncoder(nn.Module):
             nn.Tanh(),
         )
 
-        self.parameters_to_optimize = []
-        if not config.freeze_vision_encoder:
-            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
-        self.parameters_to_optimize += list(self.image_enc_proj.parameters())
+        self.freeze_image_encoder = config.freeze_vision_encoder
 
     def _load_pretrained_vision_encoder(self, config: SACConfig):
         """Set up CNN encoder"""
@@ -1009,6 +989,8 @@ class PretrainedImageEncoder(nn.Module):
         # TODO: (maractingi, azouitine) check the forward pass of the pretrained model
         # doesn't reach the classifier layer because we don't need it
         enc_feat = self.image_enc_layers(x).pooler_output
+        if self.freeze_image_encoder:
+            enc_feat = enc_feat.detach()
         enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
         return enc_feat
 
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 5b39d0d3..a8a858bf 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -510,7 +510,7 @@ def add_actor_information_and_train(
                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
                 actor_grad_norm = torch.nn.utils.clip_grad_norm_(
-                    parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
+                    parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
                 ).item()
                 optimizers["actor"].step()
 
@@ -773,12 +773,14 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     """
     optimizer_actor = torch.optim.Adam(
         # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
-        params=policy.actor.parameters_to_optimize,
+        params=[
+            p
+            for n, p in policy.actor.named_parameters()
+            if not n.startswith("encoder") or not policy.config.shared_encoder
+        ],
         lr=cfg.policy.actor_lr,
     )
-    optimizer_critic = torch.optim.Adam(
-        params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
-    )
+    optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
 
     if cfg.policy.num_discrete_actions is not None:
         optimizer_grasp_critic = torch.optim.Adam(
@@ -1089,6 +1091,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
     parameters_queue.put(state_bytes)
 
 
+def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
+    """
+    Checks whether each parameter in the module has a gradient.
+
+    Args:
+        module (nn.Module): A PyTorch module whose parameters will be inspected.
+
+    Returns:
+        dict[str, bool]: A dictionary where each key is the parameter name and the value is
+                         True if the parameter has an associated gradient (i.e. .grad is not None),
+                         otherwise False.
+    """
+    grad_status = {}
+    for name, param in module.named_parameters():
+        grad_status[name] = param.grad is not None
+    return grad_status
+
+
+def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
+    """
+    Returns a dictionary of parameters (from the given model) that also exist in the grad_status dictionary.
+
+    Args:
+        model (nn.Module): The model whose parameters are inspected.
+        grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
+                                       whether each parameter has a gradient.
+
+    Returns:
+        dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
+    """
+    # Get the model's parameter names as a set.
+    model_param_names = {name for name, _ in model.named_parameters()}
+
+    # Intersect parameter names between the model and grad_status.
+    overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
+    return overlapping
+
+
 def process_interaction_message(
     message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
 ):
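The parameter-handling refactor above replaces the hand-maintained parameters_to_optimize lists with name-based filtering: when the encoder is shared with the critic, the actor's optimizer simply skips every parameter whose name starts with "encoder". A compact sketch with a made-up Actor module:

import torch
import torch.nn as nn


class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)  # shared with the critic in the real policy
        self.head = nn.Linear(8, 2)


actor, shared_encoder = Actor(), True
actor_params = [
    p
    for n, p in actor.named_parameters()
    if not n.startswith("encoder") or not shared_encoder
]
optimizer = torch.optim.Adam(actor_params, lr=3e-4)
print([n for n, _ in actor.named_parameters() if not n.startswith("encoder")])
# ['head.weight', 'head.bias']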