diff --git a/lerobot/scripts/server/policy_server.py b/lerobot/scripts/server/policy_server.py index 3c40360f..b2eb00fa 100644 --- a/lerobot/scripts/server/policy_server.py +++ b/lerobot/scripts/server/policy_server.py @@ -20,9 +20,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): self.policy = policy # TODO: Add device specification for policy inference - - self.observation = None - self.clients = [] + # self.observation = None + self.observation = async_inference_pb2.Observation( + transfer_state=2, + data=np.array([1], dtype=np.float32).tobytes() + ) + self.lock = threading.Lock() # keeping a list of all observations received from the robot client self.observations = [] @@ -43,12 +46,15 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): f"data size={len(observation.data)} bytes" ) - with self.lock: self.observation = observation self.observations.append(observation) - data = np.frombuffer(self.observation.data, dtype=np.float32) + data = np.frombuffer( + self.observation.data, + # observation data are stored as float32 + dtype=np.float32 + ) print(f"Current observation data: {data}") return async_inference_pb2.Empty() @@ -58,18 +64,8 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): client_id = context.peer() print(f"Client {client_id} connected for action streaming") - # Keep track of this client for sending actions with self.lock: - self.clients.append(context) - - try: - # Keep the connection alive - while context.is_active(): - time.sleep(0.1) - finally: - with self.lock: - if context in self.clients: - self.clients.remove(context) + yield self._generate_and_queue_action(self.observation) return async_inference_pb2.Empty() @@ -86,30 +82,22 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): def _generate_and_queue_action(self, observation): """Generate an action based on the observation (dummy logic). Mainly used for testing purposes""" - # Just create a random action as a response - action_data = np.random.rand(50).astype(np.float32).tobytes() - + # Debinarize the observation data + data = np.frombuffer( + observation.data, + dtype=np.float32 + ) + # dummy transform on the observation data + action = (data * 1.4).sum() + # map action to bytes + action_data = np.array([action], dtype=np.float32).tobytes() + action = async_inference_pb2.Action( transfer_state=observation.transfer_state, data=action_data ) - - # Send this action to all connected clients - dead_clients = [] - for client_context in self.clients: - try: - if client_context.is_active(): - client_context.send_initial_metadata([]) - yield action - else: - dead_clients.append(client_context) - except: - dead_clients.append(client_context) - - # Clean up dead clients, if any - for dead in dead_clients: - if dead in self.clients: - self.clients.remove(dead) + + return action def serve(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) diff --git a/lerobot/scripts/server/robot_client.py b/lerobot/scripts/server/robot_client.py index 112051a4..1970ac40 100644 --- a/lerobot/scripts/server/robot_client.py +++ b/lerobot/scripts/server/robot_client.py @@ -22,12 +22,8 @@ class RobotClient: print("Connected to policy server server") self.running = True - # Start action receiving thread - self.action_thread = threading.Thread(target=self.receive_actions) - self.action_thread.daemon = True - self.action_thread.start() - return True + except grpc.RpcError as e: print(f"Failed to connect to policy server: {e}") return False @@ -59,71 +55,59 @@ class RobotClient: except grpc.RpcError as e: print(f"Error sending observation: {e}") return False - - + def receive_actions(self): """Receive actions from the policy server""" while self.running: try: # Use StreamActions to get a stream of actions from the server for action in self.stub.StreamActions(async_inference_pb2.Empty()): - if self.action_callback: - # Convert bytes back to data (assuming numpy array) - action_data = np.frombuffer(action.data) - self.action_callback( - action_data, - action.transfer_state - ) - else: - print( - "Received action: ", - f"state={action.transfer_state}, ", - f"data size={len(action.data)} bytes" - ) - + action_data = np.frombuffer(action.data, dtype=np.float32) + print( + "Received action: ", + f"state={action.transfer_state}, ", + f"data={action_data}, ", + f"data size={len(action.data)} bytes" + ) + except grpc.RpcError as e: print(f"Error receiving actions: {e}") time.sleep(1) # Avoid tight loop on error - - def register_action_callback(self, callback): - """Register a callback for when actions are received""" - self.action_callback = callback - + def example_usage(): # Example of how to use the RobotClient client = RobotClient() if client.start(): - # Define a callback for received actions - def on_action(action_data, transfer_state): - print(f"Action received: state={transfer_state}, data={action_data[:10]}...") + # Creating & starting a thread for receiving actions + action_thread = threading.Thread(target=client.receive_actions) + action_thread.daemon = True + action_thread.start() - client.register_action_callback(on_action) - - # Send some example observations - for i in range(10): - # Create dummy observation data - observation = np.arange(10, dtype=np.float32) - - # Send it to the policy server - if i == 0: - state = async_inference_pb2.TRANSFER_BEGIN - elif i == 9: - state = async_inference_pb2.TRANSFER_END - else: - state = async_inference_pb2.TRANSFER_MIDDLE - - client.send_observation(observation, state) - print(f"Sent observation {i+1}/10") - time.sleep(0.5) - - # Keep the main thread alive to receive actions try: + # Send observations to the server in the main thread + for i in range(10): + observation = np.random.randint(0, 10, size=10).astype(np.float32) + + if i == 0: + state = async_inference_pb2.TRANSFER_BEGIN + elif i == 9: + state = async_inference_pb2.TRANSFER_END + else: + state = async_inference_pb2.TRANSFER_MIDDLE + + client.send_observation(observation, state) + time.sleep(1) + + + # Keep the main thread alive to continue receiving actions while True: time.sleep(1) + except KeyboardInterrupt: pass + finally: client.stop()