fix: send observations, receive and queue action chunks, overwrite queue periodically

This commit is contained in:
Francesco Capuano 2025-04-15 12:00:33 +02:00
parent b9b7492132
commit 68927bc17d
1 changed file with 121 additions and 37 deletions
lerobot/scripts/server

View File

@ -3,6 +3,8 @@ import time
import threading import threading
import numpy as np import numpy as np
from concurrent import futures from concurrent import futures
from queue import Queue, Empty
from typing import Optional, Union
import async_inference_pb2 # type: ignore import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore import async_inference_pb2_grpc # type: ignore
@ -11,19 +13,27 @@ class RobotClient:
def __init__(self, server_address="localhost:50051"):
    """Open a gRPC channel to the policy server and set up action-queue state.

    Args:
        server_address: host:port of the policy server to connect to.
    """
    self.channel = grpc.insecure_channel(server_address)
    self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)

    # Lifecycle flags: `running` gates every loop; `first_observation_sent`
    # tells the action-receiver thread it may start streaming.
    self.running = False
    self.first_observation_sent = False

    # Incoming action chunks land here; the lock coordinates wholesale
    # queue replacement with per-action consumption.
    self.action_chunk_size = 10
    self.action_queue = Queue()
    self.action_queue_lock = threading.Lock()

    # debugging purposes
    self.action_buffer = []
def start(self):
    """Perform the client-server handshake; return True on success."""
    try:
        # client-server handshake: confirm the policy server is reachable
        # before flipping the running flag.
        self.stub.Ready(async_inference_pb2.Empty())
    except grpc.RpcError as e:
        print(f"Failed to connect to policy server: {e}")
        return False

    print("Connected to policy server")
    self.running = True
    return True
@ -33,77 +43,151 @@ class RobotClient:
self.running = False self.running = False
self.channel.close() self.channel.close()
def send_observation(
    self,
    observation_data: Union[np.ndarray, bytes],
    transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
) -> bool:
    """Serialize one observation and ship it to the policy server.

    Returns True if the observation was sent successfully, False otherwise.
    """
    if not self.running:
        print("Client not running")
        return False

    # Anything that is not already raw bytes goes through numpy serialization.
    if isinstance(observation_data, bytes):
        payload = observation_data
    else:
        payload = np.array(observation_data).tobytes()

    message = async_inference_pb2.Observation(
        transfer_state=transfer_state,
        data=payload,
    )

    try:
        _ = self.stub.SendObservations(iter([message]))
    except grpc.RpcError as e:
        print(f"Error sending observation: {e}")
        return False

    # A successful BEGIN marks the stream as live for the action thread.
    if transfer_state == async_inference_pb2.TRANSFER_BEGIN:
        self.first_observation_sent = True
    return True
def _should_replace_queue(self, percentage_left: float = 0.5) -> bool:
    """Return True once the queue has drained to `percentage_left` of a chunk.

    Used by the receiver thread to decide when a fresh action chunk should
    overwrite whatever is still queued.
    """
    with self.action_queue_lock:
        remaining_fraction = self.action_queue.qsize() / self.action_chunk_size
    return remaining_fraction <= percentage_left
def _clear_and_refill_queue(self, actions: list[np.ndarray]):
    """Atomically replace the queue contents with a fresh action chunk.

    Args:
        actions: exactly `action_chunk_size` actions from the policy server.

    Raises:
        ValueError: if `actions` is not exactly one chunk long. (Previously
            an `assert`, which is stripped under ``python -O``; a real
            exception keeps the size contract enforced in optimized runs.
            The original message was also missing a space: "chunk!size:".)
    """
    if len(actions) != self.action_chunk_size:
        raise ValueError(
            "Action batch size must match action chunk "
            f"size: {len(actions)} != {self.action_chunk_size}"
        )

    with self.action_queue_lock:
        # Drain whatever is left of the previous chunk...
        while not self.action_queue.empty():
            try:
                self.action_queue.get_nowait()
            except Empty:
                break
        # ...then load the new one.
        for action in actions:
            self.action_queue.put(action)
def receive_actions(self):
    """Background loop: stream action chunks from the server into the queue."""
    while self.running:
        # Don't request actions before the server has seen any observation.
        if not self.first_observation_sent:
            time.sleep(0.1)
            continue

        try:
            action_batch = []
            for action in self.stub.StreamActions(async_inference_pb2.Empty()):
                # NOTE: np.frombuffer yields a flat float32 array; reshape to
                # (chunk_size, action_dim) to recover individual actions.
                action_data = np.frombuffer(
                    action.data, dtype=np.float32
                ).reshape(self.action_chunk_size, -1)
                # Bulk-append the rows (was a manual per-row append loop).
                action_batch.extend(action_data)

            # NOTE(review): if the stream ever yields more than one
            # chunk-sized message, len(action_batch) exceeds
            # action_chunk_size and _clear_and_refill_queue rejects it —
            # confirm the server sends exactly one message per stream.
            if action_batch and self._should_replace_queue():
                self._clear_and_refill_queue(action_batch)
        except grpc.RpcError as e:
            print(f"Error receiving actions: {e}")
            time.sleep(1)  # Avoid tight loop on error
def get_next_action(self) -> Optional[np.ndarray]:
    """Pop one action, or None when the queue is currently empty."""
    with self.action_queue_lock:
        try:
            return self.action_queue.get_nowait()
        except Empty:
            return None
def stream_observations(self, get_observation_fn):
    """Background loop: push observations to the server at ~10 Hz.

    Args:
        get_observation_fn: zero-argument callable producing the next
            observation to send.
    """
    is_first = True
    while self.running:
        try:
            observation = get_observation_fn()
            # The very first message opens the transfer; all later ones
            # are middles.
            state = (
                async_inference_pb2.TRANSFER_BEGIN
                if is_first
                else async_inference_pb2.TRANSFER_MIDDLE
            )
            is_first = False
            self.send_observation(observation, state)
            time.sleep(0.1)  # Adjust rate as needed
        except Exception as e:
            print(f"Error in observation sender: {e}")
            time.sleep(1)
def example_usage():
    """Demo: wire a RobotClient to mock observations and print queued actions."""
    client = RobotClient()
    if not client.start():
        return

    def get_mock_observation():
        # Ten random floats stand in for a real robot observation.
        return np.random.randint(0, 10, size=10).astype(np.float32)

    # Producer thread: streams (mock) observations to the server.
    obs_thread = threading.Thread(
        target=client.stream_observations,
        args=(get_mock_observation,),
    )
    obs_thread.daemon = True
    obs_thread.start()

    # Consumer thread: pulls action chunks back into the local queue.
    action_thread = threading.Thread(target=client.receive_actions)
    action_thread.daemon = True
    action_thread.start()

    try:
        # Main loop: drain the queue, pretending to execute each action.
        while True:
            print(client.action_queue.qsize())
            action = client.get_next_action()
            if action is not None:
                print(f"Executing action: {action}")
                time.sleep(1)
            else:
                print("No action available")
                time.sleep(0.5)
    except KeyboardInterrupt:
        pass