add: robot can send observations

parent 188bcf48fd
commit be408edf1c

@@ -0,0 +1,129 @@
import torch
import grpc
import time
import threading
import numpy as np
from concurrent import futures

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore

from lerobot.common.robot_devices.control_utils import predict_action
from lerobot.common.policies.pretrained import PreTrainedPolicy
from typing import Optional


def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
    def __init__(self, policy: Optional[PreTrainedPolicy] = None):
        self.policy = policy
        # TODO: Add device specification for policy inference

        self.observation = None
        self.clients = []
        self.lock = threading.Lock()
        # keeping a list of all observations received from the robot client
        self.observations = []

    def Ready(self, request, context):
        print("Client connected and ready")
        return async_inference_pb2.Empty()

    def SendObservations(self, request_iterator, context):
        """Receive observations from the robot client"""
        client_id = context.peer()
        print(f"Receiving observations from {client_id}")

        for observation in request_iterator:
            print(
                "Received observation: ",
                f"state={observation.transfer_state}, "
                f"data size={len(observation.data)} bytes"
            )

            with self.lock:
                self.observation = observation
                self.observations.append(observation)

            # Decode the payload for logging (the client sends float32 buffers)
            data = np.frombuffer(observation.data, dtype=np.float32)
            print(f"Current observation data: {data}")

        return async_inference_pb2.Empty()

    def StreamActions(self, request, context):
        """Stream actions to the robot client"""
        client_id = context.peer()
        print(f"Client {client_id} connected for action streaming")

        # Keep track of this client for sending actions
        with self.lock:
            self.clients.append(context)

        try:
            # Keep the connection alive; actions would be yielded from here
            while context.is_active():
                time.sleep(0.1)
        finally:
            with self.lock:
                if context in self.clients:
                    self.clients.remove(context)

        # A server-streaming handler must be a generator, so end with a bare
        # return; this placeholder produces no actions yet
        return
        yield

    def _predict_and_queue_action(self, observation):
        """Predict an action based on the observation.

        Ideally, action prediction should be general and not specific to the
        policy used; that is, this interface should be the same for
        ACT/VLA/RL-based policies etc.
        """
        # TODO: Implement the logic to predict an action based on the observation
        # TODO: Queue the action to be sent to the robot client
        raise NotImplementedError("Not implemented")

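    # One possible implementation, sketched here as a comment and assuming
    # lerobot's predict_action(observation, policy, device, use_amp) helper
    # imported above returns an action tensor. The "observation.state" key is
    # an assumption about how the raw float32 buffer maps onto the inputs the
    # policy expects.
    #
    # def _predict_and_queue_action(self, observation):
    #     obs = {
    #         "observation.state": torch.from_numpy(
    #             np.frombuffer(observation.data, dtype=np.float32).copy()
    #         )
    #     }
    #     action = predict_action(obs, self.policy, get_device(), use_amp=False)
    #     return async_inference_pb2.Action(
    #         transfer_state=observation.transfer_state,
    #         data=action.cpu().numpy().astype(np.float32).tobytes(),
    #     )
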
    def _generate_and_queue_action(self, observation):
        """Generate a dummy action based on the observation.

        Mainly used for testing purposes."""
        # Just create a random action as a response
        action_data = np.random.rand(50).astype(np.float32).tobytes()

        action = async_inference_pb2.Action(
            transfer_state=observation.transfer_state,
            data=action_data
        )

        # Send this action to all connected clients; snapshot the list under
        # the lock so StreamActions can add/remove clients concurrently
        with self.lock:
            clients = list(self.clients)

        dead_clients = []
        for client_context in clients:
            try:
                if client_context.is_active():
                    yield action
                else:
                    dead_clients.append(client_context)
            except Exception:
                dead_clients.append(client_context)

        # Clean up dead clients, if any
        with self.lock:
            for dead in dead_clients:
                if dead in self.clients:
                    self.clients.remove(dead)


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(PolicyServer(), server)
    server.add_insecure_port("[::]:50051")
    server.start()
    print("PolicyServer started on port 50051")

    try:
        while True:
            time.sleep(86400)  # Sleep for a day, or until interrupted
    except KeyboardInterrupt:
        server.stop(0)
        print("Server stopped")


if __name__ == "__main__":
    serve()

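Both files import async_inference_pb2 and async_inference_pb2_grpc, generated modules whose .proto definition is not part of this diff. A minimal sketch of what that definition would need to contain, reconstructed from the calls above (the package name, the enum zero-value, and the field numbers are assumptions):

syntax = "proto3";

package async_inference;

// Marks where a message sits in a multi-part transfer.
enum TransferState {
  TRANSFER_UNKNOWN = 0;
  TRANSFER_BEGIN = 1;
  TRANSFER_MIDDLE = 2;
  TRANSFER_END = 3;
}

message Empty {}

message Observation {
  TransferState transfer_state = 1;
  bytes data = 2;  // raw float32 buffer
}

message Action {
  TransferState transfer_state = 1;
  bytes data = 2;  // raw float32 buffer
}

service AsyncInference {
  // Handshake used by the client to check the server is up
  rpc Ready(Empty) returns (Empty);
  // Client-streaming: the robot pushes observations to the server
  rpc SendObservations(stream Observation) returns (Empty);
  // Server-streaming: the server pushes actions back to the robot
  rpc StreamActions(Empty) returns (stream Action);
}

Given such a file, the imported modules can be regenerated with python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. async_inference.proto.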
@@ -0,0 +1,132 @@
import grpc
import time
import threading
import numpy as np
from concurrent import futures

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore


class RobotClient:
    def __init__(self, server_address="localhost:50051"):
        self.channel = grpc.insecure_channel(server_address)
        self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
        self.running = False
        self.action_callback = None

    def start(self):
        """Start the robot client and connect to the policy server"""
        try:
            # Check if the server is ready
            self.stub.Ready(async_inference_pb2.Empty())
            print("Connected to policy server")
            self.running = True

            # Start action receiving thread
            self.action_thread = threading.Thread(target=self.receive_actions)
            self.action_thread.daemon = True
            self.action_thread.start()

            return True
        except grpc.RpcError as e:
            print(f"Failed to connect to policy server: {e}")
            return False

    def stop(self):
        """Stop the robot client"""
        self.running = False
        self.channel.close()

    def send_observation(self, observation_data, transfer_state=async_inference_pb2.TRANSFER_MIDDLE):
        """Send a single observation to the policy server"""
        if not self.running:
            print("Client not running")
            return False

        # Convert observation data to a float32 buffer, matching what the
        # server decodes on its side
        if not isinstance(observation_data, bytes):
            observation_data = np.array(observation_data, dtype=np.float32).tobytes()

        observation = async_inference_pb2.Observation(
            transfer_state=transfer_state,
            data=observation_data
        )

        try:
            # Send a single observation as a one-element stream
            self.stub.SendObservations(iter([observation]))
            return True
        except grpc.RpcError as e:
            print(f"Error sending observation: {e}")
            return False

    def receive_actions(self):
        """Receive actions from the policy server"""
        while self.running:
            try:
                # Use StreamActions to get a stream of actions from the server
                for action in self.stub.StreamActions(async_inference_pb2.Empty()):
                    if self.action_callback:
                        # Convert bytes back to a float32 numpy array (the
                        # server encodes actions as float32)
                        action_data = np.frombuffer(action.data, dtype=np.float32)
                        self.action_callback(action_data, action.transfer_state)
                    else:
                        print(
                            "Received action: ",
                            f"state={action.transfer_state}, "
                            f"data size={len(action.data)} bytes"
                        )
            except grpc.RpcError as e:
                print(f"Error receiving actions: {e}")
                time.sleep(1)  # Avoid tight loop on error

    def register_action_callback(self, callback):
        """Register a callback for when actions are received"""
        self.action_callback = callback


def example_usage():
    # Example of how to use the RobotClient
    client = RobotClient()

    if client.start():
        # Define a callback for received actions
        def on_action(action_data, transfer_state):
            print(f"Action received: state={transfer_state}, data={action_data[:10]}...")

        client.register_action_callback(on_action)

        # Send some example observations
        for i in range(10):
            # Create dummy observation data
            observation = np.arange(10, dtype=np.float32)

            # Send it to the policy server
            if i == 0:
                state = async_inference_pb2.TRANSFER_BEGIN
            elif i == 9:
                state = async_inference_pb2.TRANSFER_END
            else:
                state = async_inference_pb2.TRANSFER_MIDDLE

            client.send_observation(observation, state)
            print(f"Sent observation {i+1}/10")
            time.sleep(0.5)

        # Keep the main thread alive to receive actions
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            pass
        finally:
            client.stop()


if __name__ == "__main__":
    example_usage()
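To try the pair end to end, start the server first (it listens on [::]:50051), then run the client: the client performs the Ready handshake, sends ten dummy observations, and keeps a StreamActions connection open. Note that nothing in this commit feeds actions into that stream yet: StreamActions only holds the connection alive, and _generate_and_queue_action is not called anywhere, so the on_action callback will not fire until prediction is wired into the stream.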