diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index b5bfb36e..e3d3765e 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -288,6 +288,7 @@ class SACPolicy(
             next_observations: dict[str, Tensor] = batch["next_state"]
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
+            complementary_info = batch.get("complementary_info")
             loss_grasp_critic = self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
@@ -296,6 +297,7 @@ class SACPolicy(
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
             )
             return {"loss_grasp_critic": loss_grasp_critic}
         if model == "actor":
@@ -413,6 +415,7 @@ class SACPolicy(
         done,
         observation_features=None,
         next_observation_features=None,
+        complementary_info=None,
     ):
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
@@ -420,6 +423,9 @@ class SACPolicy(
         actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
         actions_discrete = actions_discrete.long()
 
+        if complementary_info is not None:
+            gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
+
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(
@@ -440,7 +446,10 @@ class SACPolicy(
             ).squeeze(-1)
 
         # Compute target Q-value with Bellman equation
-        target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
+        rewards_gripper = rewards
+        if gripper_penalties is not None:
+            rewards_gripper = rewards - gripper_penalties
+        target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(
diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py
index 8947f6d9..92ad7dc7 100644
--- a/lerobot/scripts/server/buffer.py
+++ b/lerobot/scripts/server/buffer.py
@@ -32,7 +32,7 @@ class Transition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: bool
     truncated: bool
-    complementary_info: dict[str, Any] = None
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None
 
 
 class BatchTransition(TypedDict):
@@ -42,41 +42,43 @@ class BatchTransition(TypedDict):
     next_state: dict[str, torch.Tensor]
     done: torch.Tensor
     truncated: torch.Tensor
+    complementary_info: dict[str, torch.Tensor | float | int] | None = None
 
 
 def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
-    # Move state tensors to CPU
     device = torch.device(device)
+    non_blocking = device.type == "cuda"
+
+    # Move state tensors to device
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
     }
 
-    # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
+    # Move action to device
+    transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
 
-    # No need to move reward or done, as they are float and bool
-
-    # No need to move reward or done, as they are float and bool
+    # Move reward and done if they are tensors
     if isinstance(transition["reward"], torch.Tensor):
-        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
+        transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
 
     if isinstance(transition["done"], torch.Tensor):
-        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
+        transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
 
     if isinstance(transition["truncated"], torch.Tensor):
-        transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
+        transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
 
-    # Move next_state tensors to CPU
+    # Move next_state tensors to device
     transition["next_state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda")
-        for key, val in transition["next_state"].items()
+        key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
     }
 
-    # If complementary_info is present, move its tensors to CPU
-    # if transition["complementary_info"] is not None:
-    #     transition["complementary_info"] = {
-    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
-    #     }
+    # Move complementary_info tensors if present
+    if transition.get("complementary_info") is not None:
+        transition["complementary_info"] = {
+            key: val.to(device, non_blocking=non_blocking)
+            for key, val in transition["complementary_info"].items()
+        }
+
     return transition
 
 
@@ -216,7 +218,12 @@ class ReplayBuffer:
             self.image_augmentation_function = torch.compile(base_function)
         self.use_drq = use_drq
 
-    def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor):
+    def _initialize_storage(
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        complementary_info: Optional[dict[str, torch.Tensor]] = None,
+    ):
         """Initialize the storage tensors based on the first transition."""
         # Determine shapes from the first transition
         state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
@@ -244,6 +251,26 @@ class ReplayBuffer:
         self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
 
+        # Initialize storage for complementary_info
+        self.has_complementary_info = complementary_info is not None
+        self.complementary_info_keys = []
+        self.complementary_info = {}
+
+        if self.has_complementary_info:
+            self.complementary_info_keys = list(complementary_info.keys())
+            # Pre-allocate tensors for each key in complementary_info
+            for key, value in complementary_info.items():
+                if isinstance(value, torch.Tensor):
+                    value_shape = value.squeeze(0).shape
+                    self.complementary_info[key] = torch.empty(
+                        (self.capacity, *value_shape), device=self.storage_device
+                    )
+                elif isinstance(value, (int, float)):
+                    # Handle scalar values similar to reward
+                    self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
+                else:
+                    raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
+
         self.initialized = True
 
     def __len__(self):
@@ -262,7 +289,7 @@ class ReplayBuffer:
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Initialize storage if this is the first transition
         if not self.initialized:
-            self._initialize_storage(state=state, action=action)
+            self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
 
         # Store the transition in pre-allocated tensors
         for key in self.states:
@@ -277,6 +304,17 @@ class ReplayBuffer:
         self.dones[self.position] = done
         self.truncateds[self.position] = truncated
 
+        # Handle complementary_info if provided and storage is initialized
+        if complementary_info is not None and self.has_complementary_info:
+            # Store the complementary_info
+            for key in self.complementary_info_keys:
+                if key in complementary_info:
+                    value = complementary_info[key]
+                    if isinstance(value, torch.Tensor):
+                        self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
+                    elif isinstance(value, (int, float)):
+                        self.complementary_info[key][self.position] = value
+
         self.position = (self.position + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
 
@@ -335,6 +373,13 @@ class ReplayBuffer:
         batch_dones = self.dones[idx].to(self.device).float()
         batch_truncateds = self.truncateds[idx].to(self.device).float()
 
+        # Sample complementary_info if available
+        batch_complementary_info = None
+        if self.has_complementary_info:
+            batch_complementary_info = {}
+            for key in self.complementary_info_keys:
+                batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
+
         return BatchTransition(
             state=batch_state,
             action=batch_actions,
@@ -342,6 +387,7 @@ class ReplayBuffer:
             next_state=batch_next_state,
             done=batch_dones,
             truncated=batch_truncateds,
+            complementary_info=batch_complementary_info,
         )
 
     def get_iterator(
@@ -518,7 +564,19 @@ class ReplayBuffer:
             if action_delta is not None:
                 first_action = first_action / action_delta
 
-            replay_buffer._initialize_storage(state=first_state, action=first_action)
+            # Get complementary info if available
+            first_complementary_info = None
+            if (
+                "complementary_info" in first_transition
+                and first_transition["complementary_info"] is not None
+            ):
+                first_complementary_info = {
+                    k: v.to(device) for k, v in first_transition["complementary_info"].items()
+                }
+
+            replay_buffer._initialize_storage(
+                state=first_state, action=first_action, complementary_info=first_complementary_info
+            )
 
         # Fill the buffer with all transitions
         for data in list_transition:
@@ -546,6 +604,7 @@ class ReplayBuffer:
                 next_state=data["next_state"],
                 done=data["done"],
                 truncated=False,  # NOTE: Truncation are not supported yet in lerobot dataset
+                complementary_info=data.get("complementary_info", None),
             )
 
         return replay_buffer
@@ -587,6 +646,13 @@ class ReplayBuffer:
             f_info = guess_feature_info(t=sample_val, name=key)
             features[key] = f_info
 
+        # Add complementary_info keys if available
+        if self.has_complementary_info:
+            for key in self.complementary_info_keys:
+                sample_val = self.complementary_info[key][0]
+                f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
+                features[f"complementary_info.{key}"] = f_info
+
         # Create an empty LeRobotDataset
         lerobot_dataset = LeRobotDataset.create(
             repo_id=repo_id,
@@ -620,6 +686,11 @@ class ReplayBuffer:
             frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
             frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
 
+            # Add complementary_info if available
+            if self.has_complementary_info:
+                for key in self.complementary_info_keys:
+                    frame_dict[f"complementary_info.{key}"] = self.complementary_info[key][actual_idx].cpu()
+
             # Add task field which is required by LeRobotDataset
             frame_dict["task"] = task_name
 
@@ -686,6 +757,10 @@ class ReplayBuffer:
         sample = dataset[0]
         has_done_key = "next.done" in sample
 
+        # Check for complementary_info keys
+        complementary_info_keys = [key for key in sample.keys() if key.startswith("complementary_info.")]
+        has_complementary_info = len(complementary_info_keys) > 0
+
         # If not, we need to infer it from episode boundaries
         if not has_done_key:
             print("'next.done' key not found in dataset. Inferring from episode boundaries...")
@@ -735,6 +810,16 @@ class ReplayBuffer:
                         next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
                     next_state = next_state_data
 
+            # ----- 5) Complementary info (if available) -----
+            complementary_info = None
+            if has_complementary_info:
+                complementary_info = {}
+                for key in complementary_info_keys:
+                    # Strip the "complementary_info." prefix to get the actual key
+                    clean_key = key[len("complementary_info.") :]
+                    val = current_sample[key]
+                    complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
+
             # ----- Construct the Transition -----
             transition = Transition(
                 state=current_state,
@@ -743,6 +828,7 @@ class ReplayBuffer:
                 next_state=next_state,
                 done=done,
                 truncated=truncated,
+                complementary_info=complementary_info,
             )
             transitions.append(transition)
 
@@ -775,32 +861,33 @@ def concatenate_batch_transitions(
     left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
 ) -> BatchTransition:
     """NOTE: Be careful it change the left_batch_transitions in place"""
+    # Concatenate state fields
     left_batch_transitions["state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["state"][key],
-                right_batch_transition["state"][key],
-            ],
+            [left_batch_transitions["state"][key], right_batch_transition["state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["state"]
     }
+
+    # Concatenate basic fields
     left_batch_transitions["action"] = torch.cat(
         [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
     )
     left_batch_transitions["reward"] = torch.cat(
         [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
     )
+
+    # Concatenate next_state fields
     left_batch_transitions["next_state"] = {
         key: torch.cat(
-            [
-                left_batch_transitions["next_state"][key],
-                right_batch_transition["next_state"][key],
-            ],
+            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
             dim=0,
         )
         for key in left_batch_transitions["next_state"]
     }
+
+    # Concatenate done and truncated fields
     left_batch_transitions["done"] = torch.cat(
         [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
     )
@@ -808,8 +895,26 @@ def concatenate_batch_transitions(
         [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
         dim=0,
     )
+
+    # Handle complementary_info
+    left_info = left_batch_transitions.get("complementary_info")
+    right_info = right_batch_transition.get("complementary_info")
+
+    # Only process if right_info exists
+    if right_info is not None:
+        # Initialize left complementary_info if needed
+        if left_info is None:
+            left_batch_transitions["complementary_info"] = right_info
+        else:
+            # Concatenate each field
+            for key in right_info:
+                if key in left_info:
+                    left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
+                else:
+                    left_info[key] = right_info[key]
+
     return left_batch_transitions
 
 
 if __name__ == "__main__":
-    pass  # All test code is currently commented out
+    pass
diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py
index 37586fe9..5489d6dc 100644
--- a/lerobot/scripts/server/learner_server.py
+++ b/lerobot/scripts/server/learner_server.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+# !/usr/bin/env python
 
 # Copyright 2024 The HuggingFace Inc. team.
 # All rights reserved.