Add: server computes actions; the robot's daemon continuously reads them
This commit is contained in:
parent
be408edf1c
commit
485d64c8f4
|
@ -20,9 +20,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||
self.policy = policy
|
||||
|
||||
# TODO: Add device specification for policy inference
|
||||
|
||||
self.observation = None
|
||||
self.clients = []
|
||||
# self.observation = None
|
||||
self.observation = async_inference_pb2.Observation(
|
||||
transfer_state=2,
|
||||
data=np.array([1], dtype=np.float32).tobytes()
|
||||
)
|
||||
|
||||
self.lock = threading.Lock()
|
||||
# keeping a list of all observations received from the robot client
|
||||
self.observations = []
|
||||
|
@ -43,12 +46,15 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||
f"data size={len(observation.data)} bytes"
|
||||
)
|
||||
|
||||
|
||||
with self.lock:
|
||||
self.observation = observation
|
||||
self.observations.append(observation)
|
||||
|
||||
data = np.frombuffer(self.observation.data, dtype=np.float32)
|
||||
data = np.frombuffer(
|
||||
self.observation.data,
|
||||
# observation data are stored as float32
|
||||
dtype=np.float32
|
||||
)
|
||||
print(f"Current observation data: {data}")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
@ -58,18 +64,8 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||
client_id = context.peer()
|
||||
print(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Keep track of this client for sending actions
|
||||
with self.lock:
|
||||
self.clients.append(context)
|
||||
|
||||
try:
|
||||
# Keep the connection alive
|
||||
while context.is_active():
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
with self.lock:
|
||||
if context in self.clients:
|
||||
self.clients.remove(context)
|
||||
yield self._generate_and_queue_action(self.observation)
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
|
@ -86,30 +82,22 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||
def _generate_and_queue_action(self, observation):
    """Generate a dummy action from *observation* and return it.

    Mainly used for testing purposes: applies a fixed transform to the
    observation payload and wraps the result in an Action message.

    Args:
        observation: an ``async_inference_pb2.Observation`` whose ``data``
            field holds float32 bytes.

    Returns:
        An ``async_inference_pb2.Action`` carrying the transformed value
        (float32 bytes) and the observation's ``transfer_state``.
    """
    # De-binarize the observation payload (stored as float32 — see the
    # serialization side of this file).
    data = np.frombuffer(observation.data, dtype=np.float32)

    # Dummy transform on the observation data.
    action_value = (data * 1.4).sum()

    # Map the scalar action back to float32 bytes.
    action_bytes = np.array([action_value], dtype=np.float32).tobytes()

    action = async_inference_pb2.Action(
        transfer_state=observation.transfer_state,
        data=action_bytes,
    )

    # NOTE(fix): the previous version contained a stray `yield action`
    # inside a client-broadcast loop, which silently turned this function
    # into a generator — callers doing
    # `yield self._generate_and_queue_action(obs)` would stream a generator
    # object instead of an Action. Actions are delivered through the
    # StreamActions generator, so here we only prune dead client contexts
    # (under the lock, since `self.clients` is shared across RPC threads).
    with self.lock:
        self.clients = [c for c in self.clients if c.is_active()]

    return action
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
|
|
|
@ -22,12 +22,8 @@ class RobotClient:
|
|||
print("Connected to policy server server")
|
||||
self.running = True
|
||||
|
||||
# Start action receiving thread
|
||||
self.action_thread = threading.Thread(target=self.receive_actions)
|
||||
self.action_thread.daemon = True
|
||||
self.action_thread.start()
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
print(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
@ -59,71 +55,59 @@ class RobotClient:
|
|||
except grpc.RpcError as e:
|
||||
print(f"Error sending observation: {e}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def receive_actions(self):
    """Continuously pull actions streamed by the policy server.

    Runs until ``self.running`` is cleared. Each received action payload
    is decoded from float32 bytes and either forwarded to the registered
    callback or logged to stdout. On stream errors the loop backs off
    briefly and retries.
    """
    while self.running:
        try:
            # StreamActions yields a server-side stream of Action messages.
            for action in self.stub.StreamActions(async_inference_pb2.Empty()):
                # Action payloads are serialized as float32 (matching the
                # server's encoding). The previous version omitted `dtype`
                # in the callback path, so np.frombuffer defaulted to
                # float64 and misinterpreted the buffer.
                action_data = np.frombuffer(action.data, dtype=np.float32)

                if self.action_callback:
                    self.action_callback(action_data, action.transfer_state)
                else:
                    print(
                        "Received action: ",
                        f"state={action.transfer_state}, ",
                        f"data={action_data}, ",
                        f"data size={len(action.data)} bytes"
                    )

        except grpc.RpcError as e:
            print(f"Error receiving actions: {e}")
            time.sleep(1)  # Avoid a tight retry loop on stream errors
|
||||
|
||||
def register_action_callback(self, callback):
    """Register *callback* to be invoked for each action received.

    The callback is called as ``callback(action_data, transfer_state)``
    by the action-receiving loop.
    """
    self.action_callback = callback
|
||||
|
||||
|
||||
|
||||
def example_usage():
    """Demonstrate RobotClient: stream 10 observations, then idle for actions.

    ``client.start()`` already spawns a daemon thread running
    ``receive_actions``, so this function only registers a callback and
    feeds observations from the main thread.

    NOTE(fix): the previous version sent the 10 observations twice (two
    near-identical loops) and spawned a second, redundant receive thread;
    the first loop also ran outside the try/finally, so an interrupt
    there skipped ``client.stop()``.
    """
    client = RobotClient()

    if client.start():
        # Callback invoked by the client's receive thread for each action.
        def on_action(action_data, transfer_state):
            print(f"Action received: state={transfer_state}, data={action_data[:10]}...")

        client.register_action_callback(on_action)

        try:
            # Send observations to the server in the main thread.
            for i in range(10):
                # Create dummy observation data.
                observation = np.random.randint(0, 10, size=10).astype(np.float32)

                # Mark the transfer boundaries so the server can tell
                # where a stream of observations begins and ends.
                if i == 0:
                    state = async_inference_pb2.TRANSFER_BEGIN
                elif i == 9:
                    state = async_inference_pb2.TRANSFER_END
                else:
                    state = async_inference_pb2.TRANSFER_MIDDLE

                client.send_observation(observation, state)
                print(f"Sent observation {i+1}/10")
                time.sleep(1)

            # Keep the main thread alive to continue receiving actions.
            while True:
                time.sleep(1)

        except KeyboardInterrupt:
            pass

        finally:
            client.stop()
|
||||
|
||||
|
|
Loading…
Reference in New Issue