fix: send obs, receives and queues actions chunk, overwrites queue periodically
This commit is contained in:
parent
b9b7492132
commit
68927bc17d
lerobot/scripts/server
|
@ -3,6 +3,8 @@ import time
|
||||||
import threading
|
import threading
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
|
from queue import Queue, Empty
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import async_inference_pb2 # type: ignore
|
import async_inference_pb2 # type: ignore
|
||||||
import async_inference_pb2_grpc # type: ignore
|
import async_inference_pb2_grpc # type: ignore
|
||||||
|
@ -11,19 +13,27 @@ class RobotClient:
|
||||||
def __init__(self, server_address="localhost:50051", action_chunk_size=10):
    """Create a client bound to a policy server.

    Args:
        server_address: ``host:port`` string handed to
            ``grpc.insecure_channel``.
        action_chunk_size: Number of actions expected in every chunk
            received from the server. It drives both the reshape of
            incoming payloads and the queue-refill threshold, so it must
            be a positive integer.

    Raises:
        ValueError: If ``action_chunk_size`` is not positive.
    """
    if action_chunk_size <= 0:
        raise ValueError(
            f"action_chunk_size must be positive, got {action_chunk_size}"
        )

    self.channel = grpc.insecure_channel(server_address)
    self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)

    self.running = False
    # Set once an observation tagged TRANSFER_BEGIN was sent successfully;
    # the action receiver waits on this flag before opening its stream.
    self.first_observation_sent = False
    self.action_chunk_size = action_chunk_size

    # Pending actions plus the lock guarding compound operations
    # (size check, clear-and-refill, pop) on that queue.
    self.action_queue = Queue()
    self.action_queue_lock = threading.Lock()

    # debugging purposes
    self.action_buffer = []
|
||||||
|
|
||||||
def start(self):
    """Perform the handshake with the policy server and mark the client running.

    Returns:
        True when the ``Ready`` call succeeded, False when it raised
        ``grpc.RpcError``.
    """
    try:
        # client-server handshake
        self.stub.Ready(async_inference_pb2.Empty())
    except grpc.RpcError as e:
        print(f"Failed to connect to policy server: {e}")
        return False

    print("Connected to policy server")
    self.running = True
    return True
|
||||||
|
@ -33,77 +43,151 @@ class RobotClient:
|
||||||
self.running = False
|
self.running = False
|
||||||
self.channel.close()
|
self.channel.close()
|
||||||
|
|
||||||
def send_observation(
    self,
    observation_data: Union[np.ndarray, bytes],
    transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE
) -> bool:
    """Send one observation to the policy server.

    Args:
        observation_data: Raw bytes, or anything ``np.array`` accepts; the
            latter is serialised with ``tobytes()`` before sending.
        transfer_state: Transfer-state tag attached to the message.

    Returns:
        True if the observation was sent, False if the client is not
        running or the RPC raised ``grpc.RpcError``.
    """
    if not self.running:
        print("Client not running")
        return False

    # The wire format is raw bytes; convert anything else through numpy.
    if isinstance(observation_data, bytes):
        payload = observation_data
    else:
        payload = np.array(observation_data).tobytes()

    message = async_inference_pb2.Observation(
        transfer_state=transfer_state,
        data=payload
    )

    try:
        self.stub.SendObservations(iter([message]))
    except grpc.RpcError as e:
        print(f"Error sending observation: {e}")
        return False

    # Record that the opening observation went out so the action receiver
    # may start streaming.
    if transfer_state == async_inference_pb2.TRANSFER_BEGIN:
        self.first_observation_sent = True
    return True
|
||||||
|
|
||||||
|
def _should_replace_queue(self, percentage_left: float = 0.5) -> bool:
    """Tell whether the action queue has drained enough to be replaced.

    Args:
        percentage_left: Threshold expressed as a fraction of
            ``action_chunk_size`` (0.5 means "half a chunk or less").

    Returns:
        True when the queued-action count divided by the chunk size is at
        or below ``percentage_left``.
    """
    with self.action_queue_lock:
        fill_ratio = self.action_queue.qsize() / self.action_chunk_size
    return fill_ratio <= percentage_left
|
||||||
|
|
||||||
|
def _clear_and_refill_queue(self, actions: list[np.ndarray]):
    """Discard every queued action and enqueue ``actions`` in order.

    Args:
        actions: The new chunk; must contain exactly
            ``action_chunk_size`` items.

    Raises:
        ValueError: If the chunk length differs from
            ``action_chunk_size``. The queue is left untouched in that
            case because the check runs before any mutation.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if len(actions) != self.action_chunk_size:
        raise ValueError(
            "Action batch size must match action chunk! "
            f"size: {len(actions)} != {self.action_chunk_size}"
        )

    # Hold the lock across both phases so a consumer never observes a
    # half-cleared or half-filled queue.
    with self.action_queue_lock:
        # Clear the queue
        while True:
            try:
                self.action_queue.get_nowait()
            except Empty:
                break

        # Fill with new actions
        for action in actions:
            self.action_queue.put(action)
|
||||||
|
|
||||||
|
|
||||||
def receive_actions(self):
    """Receive action chunks from the policy server while the client runs.

    Each stream message is decoded as float32 and reshaped to
    ``(action_chunk_size, -1)``. When the local queue has drained enough
    (see ``_should_replace_queue``) its content is replaced by the rows of
    that message. gRPC errors are logged and the stream is reopened after
    a one-second pause.
    """
    while self.running:
        # Wait until first observation is sent
        if not self.first_observation_sent:
            time.sleep(0.1)
            continue

        try:
            # Use StreamActions to get a stream of actions from the server
            for action in self.stub.StreamActions(async_inference_pb2.Empty()):
                # NOTE: reading from buffer with numpy requires reshaping
                try:
                    action_data = np.frombuffer(
                        action.data, dtype=np.float32
                    ).reshape(self.action_chunk_size, -1)
                except ValueError as e:
                    # A payload that is empty or not a multiple of the
                    # chunk size cannot be reshaped; skip it rather than
                    # letting the exception end this receiver thread.
                    print(f"Skipping malformed action chunk: {e}")
                    continue

                # Build the batch from this message only. Accumulating
                # rows across messages would exceed action_chunk_size as
                # soon as a second chunk arrives.
                action_batch = list(action_data)

                # Replace entire queue with new batch of actions
                if action_batch and self._should_replace_queue():
                    self._clear_and_refill_queue(action_batch)

        except grpc.RpcError as e:
            print(f"Error receiving actions: {e}")
            time.sleep(1)  # Avoid tight loop on error
|
||||||
|
|
||||||
|
def get_next_action(self) -> Optional[np.ndarray]:
    """Pop the oldest queued action without blocking.

    Returns:
        The next action, or None when the queue is empty.
    """
    with self.action_queue_lock:
        try:
            next_action = self.action_queue.get_nowait()
        except Empty:
            next_action = None
    return next_action
|
||||||
|
|
||||||
|
def stream_observations(self, get_observation_fn):
    """Continuously stream observations to the server while the client runs.

    Args:
        get_observation_fn: Zero-argument callable returning the next
            observation to send.

    Observations are tagged TRANSFER_BEGIN until one send succeeds, and
    TRANSFER_MIDDLE afterwards. Any exception raised while producing or
    sending an observation is logged and the loop resumes after a
    one-second pause.
    """
    first_observation = True
    while self.running:
        try:
            observation = get_observation_fn()

            # Set appropriate transfer state
            if first_observation:
                state = async_inference_pb2.TRANSFER_BEGIN
            else:
                state = async_inference_pb2.TRANSFER_MIDDLE

            # Only leave the "first observation" phase once a send has
            # actually succeeded; otherwise a failed opening send would
            # never be retried as TRANSFER_BEGIN.
            if self.send_observation(observation, state):
                first_observation = False
            time.sleep(0.1)  # Adjust rate as needed

        except Exception as e:
            print(f"Error in observation sender: {e}")
            time.sleep(1)
|
||||||
|
|
||||||
def example_usage():
|
def example_usage():
|
||||||
# Example of how to use the RobotClient
|
# Example of how to use the RobotClient
|
||||||
client = RobotClient()
|
client = RobotClient()
|
||||||
|
|
||||||
if client.start():
|
if client.start():
|
||||||
# Creating & starting a thread for receiving actions
|
# Function to generate mock observations
|
||||||
|
def get_mock_observation():
|
||||||
|
return np.random.randint(0, 10, size=10).astype(np.float32)
|
||||||
|
|
||||||
|
# Create and start observation sender thread
|
||||||
|
obs_thread = threading.Thread(
|
||||||
|
target=client.stream_observations,
|
||||||
|
args=(get_mock_observation,)
|
||||||
|
)
|
||||||
|
obs_thread.daemon = True
|
||||||
|
obs_thread.start()
|
||||||
|
|
||||||
|
# Create and start action receiver thread
|
||||||
action_thread = threading.Thread(target=client.receive_actions)
|
action_thread = threading.Thread(target=client.receive_actions)
|
||||||
action_thread.daemon = True
|
action_thread.daemon = True
|
||||||
action_thread.start()
|
action_thread.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Send observations to the server in the main thread
|
# Main loop - action execution
|
||||||
for i in range(10):
|
|
||||||
observation = np.random.randint(0, 10, size=10).astype(np.float32)
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
state = async_inference_pb2.TRANSFER_BEGIN
|
|
||||||
elif i == 9:
|
|
||||||
state = async_inference_pb2.TRANSFER_END
|
|
||||||
else:
|
|
||||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
|
||||||
|
|
||||||
client.send_observation(observation, state)
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
|
|
||||||
# Keep the main thread alive to continue receiving actions
|
|
||||||
while True:
|
while True:
|
||||||
time.sleep(1)
|
print(client.action_queue.qsize())
|
||||||
|
action = client.get_next_action()
|
||||||
|
if action is not None:
|
||||||
|
print(f"Executing action: {action}")
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
print("No action available")
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue