From be408edf1ce3ea7bbe7ffb0709385aa0a8049314 Mon Sep 17 00:00:00 2001
From: Francesco Capuano
Date: Mon, 14 Apr 2025 17:29:21 +0200
Subject: [PATCH] add: robot can send observations

---
 lerobot/scripts/server/policy_server.py | 129 +++++++++++++++++++++++
 lerobot/scripts/server/robot_client.py | 132 ++++++++++++++++++++++++
 2 files changed, 261 insertions(+)
 create mode 100644 lerobot/scripts/server/policy_server.py
 create mode 100644 lerobot/scripts/server/robot_client.py

diff --git a/lerobot/scripts/server/policy_server.py b/lerobot/scripts/server/policy_server.py
new file mode 100644
index 00000000..3c40360f
--- /dev/null
+++ b/lerobot/scripts/server/policy_server.py
@@ -0,0 +1,129 @@
import torch
import grpc
import time
import threading
import numpy as np
from concurrent import futures
from typing import Optional

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore

from lerobot.common.robot_devices.control_utils import predict_action
from lerobot.common.policies.pretrained import PreTrainedPolicy


def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
    def __init__(self, policy: Optional[PreTrainedPolicy] = None):
        self.policy = policy

        # TODO: Add device specification for policy inference

        self.observation = None
        self.clients = []
        self.lock = threading.Lock()
        # keep a list of all observations received from the robot client
        self.observations = []

    def Ready(self, request, context):
        print("Client connected and ready")
        return async_inference_pb2.Empty()

    def SendObservations(self, request_iterator, context):
        """Receive observations streamed by the robot client."""
        client_id = context.peer()
        print(f"Receiving observations from {client_id}")

        for observation in request_iterator:
            print(
                "Received observation: "
                f"state={observation.transfer_state}, "
                f"data size={len(observation.data)} bytes"
            )

            with self.lock:
                self.observation = observation
                self.observations.append(observation)

            data = np.frombuffer(observation.data, dtype=np.float32)
            print(f"Current observation data: {data}")

        return async_inference_pb2.Empty()

    def StreamActions(self, request, context):
        """Stream actions to the robot client (unary-stream RPC, one stream
        per connected client)."""
        client_id = context.peer()
        print(f"Client {client_id} connected for action streaming")

        # Keep track of this client for bookkeeping
        with self.lock:
            self.clients.append(context)

        try:
            # Keep the stream open, yielding a (dummy) action whenever a new
            # observation arrives; real policy-based prediction is still a TODO
            while context.is_active():
                with self.lock:
                    observation = self.observation
                    self.observation = None  # consume the latest observation
                if observation is not None:
                    yield from self._generate_and_queue_action(observation)
                time.sleep(0.1)
        finally:
            with self.lock:
                if context in self.clients:
                    self.clients.remove(context)

    def _predict_and_queue_action(self, observation):
        """Predict an action based on the observation.

        Ideally, action prediction should be general and not specific to the
        policy used; that is, this interface should be the same for ACT-, VLA-
        and RL-based policies alike.
        """
        # TODO: Implement the logic to predict an action based on the observation
        # TODO: Queue the action to be sent to the robot client
        raise NotImplementedError("Not implemented")
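    # A possible shape for the predictive path, sketched here for reference
    # only (not part of the original patch). It reuses the `predict_action`
    # helper imported above; the flat-float32 decoding and the
    # "observation.state" key are assumptions about how observations will be
    # encoded, and `predict_action`'s (observation, policy, device, use_amp)
    # signature should be checked against the lerobot version in use.
    def _predict_action_sketch(self, observation):
        # np.frombuffer returns a read-only view, so copy before torch wraps it
        obs_tensor = torch.from_numpy(
            np.frombuffer(observation.data, dtype=np.float32).copy()
        )
        action = predict_action(
            {"observation.state": obs_tensor},  # hypothetical key layout
            self.policy,
            get_device(),
            use_amp=False,
        )
        return async_inference_pb2.Action(
            transfer_state=observation.transfer_state,
            data=action.numpy().astype(np.float32).tobytes(),
        )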
    def _generate_and_queue_action(self, observation):
        """Generate a dummy action in response to an observation.

        Mainly used for testing purposes: yields a single random action
        mirroring the observation's transfer state. Fan-out to multiple
        clients happens naturally, since every connected client consumes its
        own StreamActions generator.
        """
        # Just create a random action as a response
        action_data = np.random.rand(50).astype(np.float32).tobytes()

        yield async_inference_pb2.Action(
            transfer_state=observation.transfer_state,
            data=action_data,
        )


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(PolicyServer(), server)
    server.add_insecure_port('[::]:50051')
    server.start()
    print("PolicyServer started on port 50051")

    try:
        while True:
            time.sleep(86400)  # Sleep for a day at a time, until interrupted
    except KeyboardInterrupt:
        server.stop(0)
        print("Server stopped")


if __name__ == "__main__":
    serve()

diff --git a/lerobot/scripts/server/robot_client.py b/lerobot/scripts/server/robot_client.py
new file mode 100644
index 00000000..112051a4
--- /dev/null
+++ b/lerobot/scripts/server/robot_client.py
@@ -0,0 +1,132 @@
import grpc
import time
import threading
import numpy as np

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore


class RobotClient:
    def __init__(self, server_address="localhost:50051"):
        self.channel = grpc.insecure_channel(server_address)
        self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
        self.running = False
        self.action_callback = None

    def start(self):
        """Start the robot client and connect to the policy server."""
        try:
            # Check that the server is ready
            self.stub.Ready(async_inference_pb2.Empty())
            print("Connected to policy server")
            self.running = True

            # Start the action-receiving thread
            self.action_thread = threading.Thread(target=self.receive_actions)
            self.action_thread.daemon = True
            self.action_thread.start()

            return True
        except grpc.RpcError as e:
            print(f"Failed to connect to policy server: {e}")
            return False

    def stop(self):
        """Stop the robot client."""
        self.running = False
        self.channel.close()

    def send_observation(self, observation_data, transfer_state=async_inference_pb2.TRANSFER_MIDDLE):
        """Send a single observation to the policy server."""
        if not self.running:
            print("Client not running")
            return False

        # Convert observation data to bytes, matching the float32 dtype the
        # server uses to decode it
        if not isinstance(observation_data, bytes):
            observation_data = np.array(observation_data, dtype=np.float32).tobytes()

        observation = async_inference_pb2.Observation(
            transfer_state=transfer_state,
            data=observation_data,
        )

        try:
            # SendObservations is a client-streaming RPC, so a single
            # observation is wrapped in a one-element iterator
            self.stub.SendObservations(iter([observation]))
            return True
        except grpc.RpcError as e:
            print(f"Error sending observation: {e}")
            return False
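    # Sketch, not part of the original patch: `send_observation` opens a new
    # stream per call, while the TRANSFER_BEGIN/MIDDLE/END flags suggest a
    # whole sequence could travel over a single client-streaming RPC. One
    # possible helper (name and chunking policy are assumptions); expects a
    # sequence, not a generator:
    def send_observation_sequence(self, observations):
        def tagged():
            last = len(observations) - 1
            for i, obs in enumerate(observations):
                # A one-element sequence is tagged TRANSFER_BEGIN only
                if i == 0:
                    state = async_inference_pb2.TRANSFER_BEGIN
                elif i == last:
                    state = async_inference_pb2.TRANSFER_END
                else:
                    state = async_inference_pb2.TRANSFER_MIDDLE
                yield async_inference_pb2.Observation(
                    transfer_state=state,
                    data=np.array(obs, dtype=np.float32).tobytes(),
                )

        try:
            self.stub.SendObservations(tagged())
            return True
        except grpc.RpcError as e:
            print(f"Error sending observation sequence: {e}")
            return False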
    def receive_actions(self):
        """Receive actions streamed by the policy server."""
        while self.running:
            try:
                # Use StreamActions to get a stream of actions from the server
                for action in self.stub.StreamActions(async_inference_pb2.Empty()):
                    # Decode the raw bytes with the same dtype the server used
                    action_data = np.frombuffer(action.data, dtype=np.float32)
                    if self.action_callback:
                        self.action_callback(action_data, action.transfer_state)
                    else:
                        print(
                            "Received action: "
                            f"state={action.transfer_state}, "
                            f"data size={len(action.data)} bytes"
                        )
            except grpc.RpcError as e:
                print(f"Error receiving actions: {e}")
                time.sleep(1)  # Avoid a tight loop on error

    def register_action_callback(self, callback):
        """Register a callback invoked whenever an action is received."""
        self.action_callback = callback


def example_usage():
    # Example of how to use the RobotClient
    client = RobotClient()

    if client.start():
        # Define a callback for received actions
        def on_action(action_data, transfer_state):
            print(f"Action received: state={transfer_state}, data={action_data[:10]}...")

        client.register_action_callback(on_action)

        # Send some example observations
        for i in range(10):
            # Create dummy observation data
            observation = np.arange(10, dtype=np.float32)

            # Tag the first and last transfers so the server can tell where
            # the sequence begins and ends
            if i == 0:
                state = async_inference_pb2.TRANSFER_BEGIN
            elif i == 9:
                state = async_inference_pb2.TRANSFER_END
            else:
                state = async_inference_pb2.TRANSFER_MIDDLE

            client.send_observation(observation, state)
            print(f"Sent observation {i + 1}/10")
            time.sleep(0.5)

        # Keep the main thread alive to receive actions
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            pass
        finally:
            client.stop()


if __name__ == "__main__":
    example_usage()
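Note: the patch imports the generated `async_inference_pb2` / `async_inference_pb2_grpc` modules but does not include the `.proto` they come from. Reconstructed from the calls above, the service definition is presumably along these lines; the package name, the enum's zero value, and the field numbers are guesses, not part of this diff:

syntax = "proto3";

package async_inference;

// Marks where a message sits inside a multi-message transfer.
enum TransferState {
  TRANSFER_UNKNOWN = 0;  // proto3 requires a zero value; name is a guess
  TRANSFER_BEGIN = 1;
  TRANSFER_MIDDLE = 2;
  TRANSFER_END = 3;
}

message Observation {
  TransferState transfer_state = 1;
  bytes data = 2;  // raw float32 buffer, as encoded by the client
}

message Action {
  TransferState transfer_state = 1;
  bytes data = 2;  // raw float32 buffer, as encoded by the server
}

message Empty {}

service AsyncInference {
  // Client checks that the server is up before streaming.
  rpc Ready(Empty) returns (Empty);
  // Robot streams observations to the policy server.
  rpc SendObservations(stream Observation) returns (Empty);
  // Policy server streams actions back to the robot.
  rpc StreamActions(Empty) returns (stream Action);
}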