diff --git a/lerobot/scripts/server/robot_client.py b/lerobot/scripts/server/robot_client.py index 1970ac40..a535e7ea 100644 --- a/lerobot/scripts/server/robot_client.py +++ b/lerobot/scripts/server/robot_client.py @@ -3,6 +3,8 @@ import time import threading import numpy as np from concurrent import futures +from queue import Queue, Empty +from typing import Optional, Union import async_inference_pb2 # type: ignore import async_inference_pb2_grpc # type: ignore @@ -11,19 +13,27 @@ class RobotClient: def __init__(self, server_address="localhost:50051"): self.channel = grpc.insecure_channel(server_address) self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) + self.running = False - self.action_callback = None + self.first_observation_sent = False + self.action_chunk_size = 10 + + self.action_queue = Queue() + self.action_queue_lock = threading.Lock() + + # debugging purposes + self.action_buffer = [] def start(self): """Start the robot client and connect to the policy server""" try: - # Check if the server is ready + # client-server handshake self.stub.Ready(async_inference_pb2.Empty()) - print("Connected to policy server server") - self.running = True + print("Connected to policy server") + self.running = True return True - + except grpc.RpcError as e: print(f"Failed to connect to policy server: {e}") return False @@ -33,77 +43,151 @@ class RobotClient: self.running = False self.channel.close() - def send_observation(self, observation_data, transfer_state=async_inference_pb2.TRANSFER_MIDDLE): - """Send a single observation to the policy server""" + def send_observation( + self, + observation_data: Union[np.ndarray, bytes], + transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE + ) -> bool: + """Send observation to the policy server. + Returns True if the observation was sent successfully, False otherwise.""" if not self.running: print("Client not running") return False - + # Convert observation data to bytes if not isinstance(observation_data, bytes): observation_data = np.array(observation_data).tobytes() - + observation = async_inference_pb2.Observation( transfer_state=transfer_state, data=observation_data ) - + try: - # For a single observation - response_future = self.stub.SendObservations(iter([observation])) + _ = self.stub.SendObservations(iter([observation])) + if transfer_state == async_inference_pb2.TRANSFER_BEGIN: + self.first_observation_sent = True return True + except grpc.RpcError as e: print(f"Error sending observation: {e}") return False + + def _should_replace_queue(self, percentage_left: float = 0.5) -> bool: + """Check if we should replace the queue based on consumption rate""" + with self.action_queue_lock: + current_size = self.action_queue.qsize() + return current_size/self.action_chunk_size <= percentage_left + + def _clear_and_refill_queue(self, actions: list[np.ndarray]): + """Clear the existing queue and fill it with new actions""" + assert len(actions) == self.action_chunk_size, \ + f"Action batch size must match action chunk!" \ + f"size: {len(actions)} != {self.action_chunk_size}" + + with self.action_queue_lock: + # Clear the queue + while not self.action_queue.empty(): + try: + self.action_queue.get_nowait() + except Empty: + break + + # Fill with new actions + for action in actions: + self.action_queue.put(action) + def receive_actions(self): """Receive actions from the policy server""" while self.running: + # Wait until first observation is sent + if not self.first_observation_sent: + time.sleep(0.1) + continue + try: # Use StreamActions to get a stream of actions from the server + action_batch = [] for action in self.stub.StreamActions(async_inference_pb2.Empty()): - action_data = np.frombuffer(action.data, dtype=np.float32) - print( - "Received action: ", - f"state={action.transfer_state}, ", - f"data={action_data}, ", - f"data size={len(action.data)} bytes" - ) + # NOTE: reading from buffer with numpy requires reshaping + action_data = np.frombuffer( + action.data, dtype=np.float32 + ).reshape(self.action_chunk_size, -1) + for a in action_data: + action_batch.append(a) + + # Replace entire queue with new batch of actions + if action_batch and self._should_replace_queue(): + self._clear_and_refill_queue(action_batch) + except grpc.RpcError as e: print(f"Error receiving actions: {e}") time.sleep(1) # Avoid tight loop on error + + def get_next_action(self) -> Optional[np.ndarray]: + """Get the next action from the queue""" + try: + with self.action_queue_lock: + return self.action_queue.get_nowait() + except Empty: + return None + + def stream_observations(self, get_observation_fn): + """Continuously stream observations to the server""" + first_observation = True + while self.running: + try: + observation = get_observation_fn() + + # Set appropriate transfer state + if first_observation: + state = async_inference_pb2.TRANSFER_BEGIN + first_observation = False + else: + state = async_inference_pb2.TRANSFER_MIDDLE + self.send_observation(observation, state) + time.sleep(0.1) # Adjust rate as needed + + except Exception as e: + print(f"Error in observation sender: {e}") + time.sleep(1) def example_usage(): # Example of how to use the RobotClient client = RobotClient() if client.start(): - # Creating & starting a thread for receiving actions + # Function to generate mock observations + def get_mock_observation(): + return np.random.randint(0, 10, size=10).astype(np.float32) + + # Create and start observation sender thread + obs_thread = threading.Thread( + target=client.stream_observations, + args=(get_mock_observation,) + ) + obs_thread.daemon = True + obs_thread.start() + + # Create and start action receiver thread action_thread = threading.Thread(target=client.receive_actions) action_thread.daemon = True action_thread.start() try: - # Send observations to the server in the main thread - for i in range(10): - observation = np.random.randint(0, 10, size=10).astype(np.float32) - - if i == 0: - state = async_inference_pb2.TRANSFER_BEGIN - elif i == 9: - state = async_inference_pb2.TRANSFER_END - else: - state = async_inference_pb2.TRANSFER_MIDDLE - - client.send_observation(observation, state) - time.sleep(1) - - - # Keep the main thread alive to continue receiving actions + # Main loop - action execution while True: - time.sleep(1) + print(client.action_queue.qsize()) + action = client.get_next_action() + if action is not None: + print(f"Executing action: {action}") + time.sleep(1) + else: + print("No action available") + time.sleep(0.5) except KeyboardInterrupt: pass